diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..b5c2efd --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,17 @@ +{ + "name": "pallets/werkzeug", + "image": "mcr.microsoft.com/devcontainers/python:3", + "customizations": { + "vscode": { + "settings": { + "python.defaultInterpreterPath": "${workspaceFolder}/.venv", + "python.terminal.activateEnvInCurrentTerminal": true, + "python.terminal.launchArgs": [ + "-X", + "dev" + ] + } + } + }, + "onCreateCommand": ".devcontainer/on-create-command.sh" +} diff --git a/.devcontainer/on-create-command.sh b/.devcontainer/on-create-command.sh new file mode 100755 index 0000000..fdf7795 --- /dev/null +++ b/.devcontainer/on-create-command.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -e + +python3 -m venv .venv +. .venv/bin/activate +pip install -U pip +pip install -r requirements/dev.txt +pip install -e . +pre-commit install --install-hooks diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..6ac59c8 --- /dev/null +++ b/.flake8 @@ -0,0 +1,29 @@ +[flake8] +extend-select = + # bugbear + B + # bugbear opinions + B9 + # implicit str concat + ISC +extend-ignore = + # slice notation whitespace, invalid + E203 + # import at top, too many circular import fixes + E402 + # line length, handled by bugbear B950 + E501 + # bare except, handled by bugbear B001 + E722 + # zip with strict=, requires python >= 3.10 + B905 + # string formatting opinion, B028 renamed to B907 + B028 + B907 +# up to 88 allowed by bugbear B950 +max-line-length = 80 +per-file-ignores = + # __init__ exports names + **/__init__.py: F401 + # LocalProxy assigns lambdas + src/werkzeug/local.py: E731 diff --git a/.github/workflows/lock.yaml b/.github/workflows/lock.yaml index b4f7633..e962fd0 100644 --- a/.github/workflows/lock.yaml +++ b/.github/workflows/lock.yaml @@ -1,15 +1,25 @@ name: 'Lock threads' +# Lock closed issues that have not received any further activity for +# two weeks. 
This does not close open issues, only humans may do that. +# We find that it is easier to respond to new issues with fresh examples +# rather than continuing discussions on old issues. on: schedule: - cron: '0 0 * * *' +permissions: + issues: write + pull-requests: write + +concurrency: + group: lock + jobs: lock: runs-on: ubuntu-latest steps: - - uses: dessant/lock-threads@v3 + - uses: dessant/lock-threads@be8aa5be94131386884a6da4189effda9b14aa21 with: - github-token: ${{ github.token }} issue-inactive-days: 14 pr-inactive-days: 14 diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml new file mode 100644 index 0000000..05681f5 --- /dev/null +++ b/.github/workflows/publish.yaml @@ -0,0 +1,72 @@ +name: Publish +on: + push: + tags: + - '*' +jobs: + build: + runs-on: ubuntu-latest + outputs: + hash: ${{ steps.hash.outputs.hash }} + steps: + - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 + - uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 + with: + python-version: '3.x' + cache: 'pip' + cache-dependency-path: 'requirements/*.txt' + - run: pip install -r requirements/build.txt + # Use the commit date instead of the current date during the build. + - run: echo "SOURCE_DATE_EPOCH=$(git log -1 --pretty=%ct)" >> $GITHUB_ENV + - run: python -m build + # Generate hashes used for provenance. + - name: generate hash + id: hash + run: cd dist && echo "hash=$(sha256sum * | base64 -w0)" >> $GITHUB_OUTPUT + - uses: actions/upload-artifact@0b7f8abb1508181956e8e162db84b466c27e18ce + with: + path: ./dist + provenance: + needs: ['build'] + permissions: + actions: read + id-token: write + contents: write + # Can't pin with hash due to how this workflow works. + uses: slsa-framework/slsa-github-generator/.github/workflows/generator_generic_slsa3.yml@v1.9.0 + with: + base64-subjects: ${{ needs.build.outputs.hash }} + create-release: + # Upload the sdist, wheels, and provenance to a GitHub release. 
They remain + # available as build artifacts for a while as well. + needs: ['provenance'] + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a + - name: create release + run: > + gh release create --draft --repo ${{ github.repository }} + ${{ github.ref_name }} + *.intoto.jsonl/* artifact/* + env: + GH_TOKEN: ${{ github.token }} + publish-pypi: + needs: ['provenance'] + # Wait for approval before attempting to upload to PyPI. This allows reviewing the + # files in the draft release. + environment: 'publish' + runs-on: ubuntu-latest + permissions: + id-token: write + steps: + - uses: actions/download-artifact@9bc31d5ccc31df68ecc42ccf4149144866c47d8a + # Try uploading to Test PyPI first, in case something fails. + - uses: pypa/gh-action-pypi-publish@b7f401de30cb6434a1e19f805ff006643653240e + with: + repository-url: https://test.pypi.org/legacy/ + packages-dir: artifact/ + - uses: pypa/gh-action-pypi-publish@b7f401de30cb6434a1e19f805ff006643653240e + with: + packages-dir: artifact/ diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index d4441ff..c1e6ea3 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -24,32 +24,27 @@ jobs: fail-fast: false matrix: include: - - {name: Linux, python: '3.10', os: ubuntu-latest, tox: py310} - - {name: Windows, python: '3.10', os: windows-latest, tox: py310} - - {name: Mac, python: '3.10', os: macos-latest, tox: py310} - - {name: '3.11-dev', python: '3.11-dev', os: ubuntu-latest, tox: py311} + - {name: Linux, python: '3.11', os: ubuntu-latest, tox: py311} + - {name: Windows, python: '3.11', os: windows-latest, tox: py311} + - {name: Mac, python: '3.11', os: macos-latest, tox: py311} + - {name: '3.12-dev', python: '3.12-dev', os: ubuntu-latest, tox: py312} + - {name: '3.10', python: '3.10', os: ubuntu-latest, tox: py310} - {name: '3.9', python: '3.9', os: ubuntu-latest, tox: py39} - {name: '3.8', 
python: '3.8', os: ubuntu-latest, tox: py38} - - {name: '3.7', python: '3.7', os: ubuntu-latest, tox: py37} - - {name: 'PyPy', python: 'pypy-3.7', os: ubuntu-latest, tox: pypy37} - - {name: Typing, python: '3.10', os: ubuntu-latest, tox: typing} + - {name: 'PyPy', python: 'pypy-3.10', os: ubuntu-latest, tox: pypy310} + - {name: Typing, python: '3.11', os: ubuntu-latest, tox: typing} steps: - - uses: actions/checkout@v3 - - uses: actions/setup-python@v4 + - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 + - uses: actions/setup-python@61a6322f88396a6271a6ee3565807d608ecaddd1 with: python-version: ${{ matrix.python }} cache: 'pip' cache-dependency-path: 'requirements/*.txt' - - name: update pip - run: | - pip install -U wheel - pip install -U setuptools - python -m pip install -U pip - name: cache mypy - uses: actions/cache@v3.0.4 + uses: actions/cache@88522ab9f39a2ea568f7027eddc7d8d8bc9d59c8 with: path: ./.mypy_cache - key: mypy|${{ matrix.python }}|${{ hashFiles('setup.cfg') }} + key: mypy|${{ matrix.python }}|${{ hashFiles('pyproject.toml') }} if: matrix.tox == 'typing' - run: pip install tox - - run: tox -e ${{ matrix.tox }} + - run: tox run -e ${{ matrix.tox }} diff --git a/.gitignore b/.gitignore index 36f3670..aecea1a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ dist /src/Werkzeug.egg-info *.pyc *.pyo -env +.venv .DS_Store docs/_build bench/a diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 55f8c13..6425015 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,42 +1,40 @@ ci: - autoupdate_branch: "2.2.x" + autoupdate_branch: "2.3.x" autoupdate_schedule: monthly repos: - repo: https://github.com/asottile/pyupgrade - rev: v2.37.3 + rev: v3.10.1 hooks: - id: pyupgrade - args: ["--py37-plus"] - - repo: https://github.com/asottile/reorder_python_imports - rev: v3.8.2 + args: ["--py38-plus"] + - repo: https://github.com/asottile/reorder-python-imports + rev: v3.10.0 hooks: - id: 
reorder-python-imports name: Reorder Python imports (src, tests) files: "^(?!examples/)" args: ["--application-directories", ".:src"] - additional_dependencies: ["setuptools>60.9"] - id: reorder-python-imports name: Reorder Python imports (examples) files: "^examples/" args: ["--application-directories", "examples"] - additional_dependencies: ["setuptools>60.9"] - repo: https://github.com/psf/black - rev: 22.6.0 + rev: 23.7.0 hooks: - id: black - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 6.1.0 hooks: - id: flake8 additional_dependencies: - flake8-bugbear - flake8-implicit-str-concat - repo: https://github.com/peterdemin/pip-compile-multi - rev: v2.4.6 + rev: v2.6.3 hooks: - id: pip-compile-multi-verify - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: fix-byte-order-marker - id: trailing-whitespace diff --git a/CHANGES.rst b/CHANGES.rst index 18e68af..6f801b9 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,276 @@ .. currentmodule:: werkzeug +Version 3.0.1 +------------- + +Released 2023-10-24 + +- Fix slow multipart parsing for large parts potentially enabling DoS + attacks. :cwe:`CWE-407` + +Version 3.0.0 +------------- + +Released 2023-09-30 + +- Remove previously deprecated code. :pr:`2768` +- Deprecate the ``__version__`` attribute. Use feature detection, or + ``importlib.metadata.version("werkzeug")``, instead. :issue:`2770` +- ``generate_password_hash`` uses scrypt by default. :issue:`2769` +- Add the ``"werkzeug.profiler"`` item to the WSGI ``environ`` dictionary + passed to `ProfilerMiddleware`'s `filename_format` function. It contains + the ``elapsed`` and ``time`` values for the profiled request. :issue:`2775` +- Explicitly marked the PathConverter as non path isolating. :pr:`2784` + + +Version 2.3.8 +------------- + +Unreleased + + +Version 2.3.7 +------------- + +Released 2023-08-14 + +- Use ``flit_core`` instead of ``setuptools`` as build backend. +- Fix parsing of multipart bodies. 
:issue:`2734` Adjust index of last newline + in data start. :issue:`2761` +- Parsing ints from header values strips spacing first. :issue:`2734` +- Fix empty file streaming when testing. :issue:`2740` +- Clearer error message when URL rule does not start with slash. :pr:`2750` +- ``Accept`` ``q`` value can be a float without a decimal part. :issue:`2751` + + +Version 2.3.6 +------------- + +Released 2023-06-08 + +- ``FileStorage.content_length`` does not fail if the form data did not provide a + value. :issue:`2726` + + +Version 2.3.5 +------------- + +Released 2023-06-07 + +- Python 3.12 compatibility. :issue:`2704` +- Fix handling of invalid base64 values in ``Authorization.from_header``. :issue:`2717` +- The debugger escapes the exception message in the page title. :pr:`2719` +- When binding ``routing.Map``, a long IDNA ``server_name`` with a port does not fail + encoding. :issue:`2700` +- ``iri_to_uri`` shows a deprecation warning instead of an error when passing bytes. + :issue:`2708` +- When parsing numbers in HTTP request headers such as ``Content-Length``, only ASCII + digits are accepted rather than any format that Python's ``int`` and ``float`` + accept. :issue:`2716` + + +Version 2.3.4 +------------- + +Released 2023-05-08 + +- ``Authorization.from_header`` and ``WWWAuthenticate.from_header`` detects tokens + that end with base64 padding (``=``). :issue:`2685` +- Remove usage of ``warnings.catch_warnings``. :issue:`2690` +- Remove ``max_form_parts`` restriction from standard form data parsing and only use + if for multipart content. :pr:`2694` +- ``Response`` will avoid converting the ``Location`` header in some cases to preserve + invalid URL schemes like ``itms-services``. :issue:`2691` + + +Version 2.3.3 +------------- + +Released 2023-05-01 + +- Fix parsing of large multipart bodies. Remove invalid leading newline, and restore + parsing speed. 
:issue:`2658, 2675` +- The cookie ``Path`` attribute is set to ``/`` by default again, to prevent clients + from falling back to RFC 6265's ``default-path`` behavior. :issue:`2672, 2679` + + +Version 2.3.2 +------------- + +Released 2023-04-28 + +- Parse the cookie ``Expires`` attribute correctly in the test client. :issue:`2669` +- ``max_content_length`` can only be enforced on streaming requests if the server + sets ``wsgi.input_terminated``. :issue:`2668` + + +Version 2.3.1 +------------- + +Released 2023-04-27 + +- Percent-encode plus (+) when building URLs and in test requests. :issue:`2657` +- Cookie values don't quote characters defined in RFC 6265. :issue:`2659` +- Include ``pyi`` files for ``datastructures`` type annotations. :issue:`2660` +- ``Authorization`` and ``WWWAuthenticate`` objects can be compared for equality. + :issue:`2665` + + +Version 2.3.0 +------------- + +Released 2023-04-25 + +- Drop support for Python 3.7. :pr:`2648` +- Remove previously deprecated code. :pr:`2592` +- Passing bytes where strings are expected is deprecated, as well as the ``charset`` + and ``errors`` parameters in many places. Anywhere that was annotated, documented, + or tested to accept bytes shows a warning. Removing this artifact of the transition + from Python 2 to 3 removes a significant amount of overhead in instance checks and + encoding cycles. In general, always work with UTF-8, the modern HTML, URL, and HTTP + standards all strongly recommend this. :issue:`2602` +- Deprecate the ``werkzeug.urls`` module, except for the ``uri_to_iri`` and + ``iri_to_uri`` functions. Use the ``urllib.parse`` library instead. :issue:`2600` +- Update which characters are considered safe when using percent encoding in URLs, + based on the WhatWG URL Standard. :issue:`2601` +- Update which characters are considered safe when using percent encoding for Unicode + filenames in downloads. :issue:`2598` +- Deprecate the ``safe_conversion`` parameter of ``iri_to_uri``. 
The ``Location`` + header is converted to IRI using the same process as everywhere else. :issue:`2609` +- Deprecate ``werkzeug.wsgi.make_line_iter`` and ``make_chunk_iter``. :pr:`2613` +- Use modern packaging metadata with ``pyproject.toml`` instead of ``setup.cfg``. + :pr:`2574` +- ``Request.get_json()`` will raise a ``415 Unsupported Media Type`` error if the + ``Content-Type`` header is not ``application/json``, instead of a generic 400. + :issue:`2550` +- A URL converter's ``part_isolating`` defaults to ``False`` if its ``regex`` contains + a ``/``. :issue:`2582` +- A custom converter's regex can have capturing groups without breaking the router. + :pr:`2596` +- The reloader can pick up arguments to ``python`` like ``-X dev``, and does not + require heuristics to determine how to reload the command. Only available + on Python >= 3.10. :issue:`2589` +- The Watchdog reloader ignores file opened events. Bump the minimum version of + Watchdog to 2.3.0. :issue:`2603` +- When using a Unix socket for the development server, the path can start with a dot. + :issue:`2595` +- Increase default work factor for PBKDF2 to 600,000 iterations. :issue:`2611` +- ``parse_options_header`` is 2-3 times faster. It conforms to :rfc:`9110`, some + invalid parts that were previously accepted are now ignored. :issue:`1628` +- The ``is_filename`` parameter to ``unquote_header_value`` is deprecated. :pr:`2614` +- Deprecate the ``extra_chars`` parameter and passing bytes to ``quote_header_value``, + the ``allow_token`` parameter to ``dump_header``, and the ``cls`` parameter and + passing bytes to ``parse_dict_header``. :pr:`2618` +- Improve ``parse_accept_header`` implementation. Parse according to :rfc:`9110`. + Discard items with invalid ``q`` values. :issue:`1623` +- ``quote_header_value`` quotes the empty string. :pr:`2618` +- ``dump_options_header`` skips ``None`` values rather than using a bare key. 
+ :pr:`2618` +- ``dump_header`` and ``dump_options_header`` will not quote a value if the key ends + with an asterisk ``*``. +- ``parse_dict_header`` will decode values with charsets. :pr:`2618` +- Refactor the ``Authorization`` and ``WWWAuthenticate`` header data structures. + :issue:`1769`, :pr:`2619` + + - Both classes have ``type``, ``parameters``, and ``token`` attributes. The + ``token`` attribute supports auth schemes that use a single opaque token rather + than ``key=value`` parameters, such as ``Bearer``. + - Neither class is a ``dict`` anymore, although they still implement getting, + setting, and deleting ``auth[key]`` and ``auth.key`` syntax, as well as + ``auth.get(key)`` and ``key in auth``. + - Both classes have a ``from_header`` class method. ``parse_authorization_header`` + and ``parse_www_authenticate_header`` are deprecated. + - The methods ``WWWAuthenticate.set_basic`` and ``set_digest`` are deprecated. + Instead, an instance should be created and assigned to + ``response.www_authenticate``. + - A list of instances can be assigned to ``response.www_authenticate`` to set + multiple header values. However, accessing the property only returns the first + instance. + +- Refactor ``parse_cookie`` and ``dump_cookie``. :pr:`2637` + + - ``parse_cookie`` is up to 40% faster, ``dump_cookie`` is up to 60% faster. + - Passing bytes to ``parse_cookie`` and ``dump_cookie`` is deprecated. The + ``dump_cookie`` ``charset`` parameter is deprecated. + - ``dump_cookie`` allows ``domain`` values that do not include a dot ``.``, and + strips off a leading dot. + - ``dump_cookie`` does not set ``path="/"`` unnecessarily by default. + +- Refactor the test client cookie implementation. :issue:`1060, 1680` + + - The ``cookie_jar`` attribute is deprecated. ``http.cookiejar`` is no longer used + for storage. + - Domain and path matching is used when sending cookies in requests. The + ``domain`` and ``path`` parameters default to ``localhost`` and ``/``. 
+ - Added a ``get_cookie`` method to inspect cookies. + - Cookies have ``decoded_key`` and ``decoded_value`` attributes to match what the + app sees rather than the encoded values a client would see. + - The first positional ``server_name`` parameter to ``set_cookie`` and + ``delete_cookie`` is deprecated. Use the ``domain`` parameter instead. + - Other parameters to ``delete_cookie`` besides ``domain``, ``path``, and + ``value`` are deprecated. + +- If ``request.max_content_length`` is set, it is checked immediately when accessing + the stream, and while reading from the stream in general, rather than only during + form parsing. :issue:`1513` +- The development server, which must not be used in production, will exhaust the + request stream up to 10GB or 1000 reads. This allows clients to see a 413 error if + ``max_content_length`` is exceeded, instead of a "connection reset" failure. + :pr:`2620` +- The development server discards header keys that contain underscores ``_``, as they + are ambiguous with dashes ``-`` in WSGI. :pr:`2622` +- ``secure_filename`` looks for more Windows reserved file names. :pr:`2623` +- Update type annotation for ``best_match`` to make ``default`` parameter clearer. + :issue:`2625` +- Multipart parser handles empty fields correctly. :issue:`2632` +- The ``Map`` ``charset`` parameter and ``Request.url_charset`` property are + deprecated. Percent encoding in URLs must always represent UTF-8 bytes. Invalid + bytes are left percent encoded rather than replaced. :issue:`2602` +- The ``Request.charset``, ``Request.encoding_errors``, ``Response.charset``, and + ``Client.charset`` attributes are deprecated. Request and response data must always + use UTF-8. :issue:`2602` +- Header values that have charset information only allow ASCII, UTF-8, and ISO-8859-1. + :pr:`2614, 2641` +- Update type annotation for ``ProfilerMiddleware`` ``stream`` parameter. + :issue:`2642` +- Use postponed evaluation of annotations. 
:pr:`2644` +- The development server escapes ASCII control characters in decoded URLs before + logging the request to the terminal. :pr:`2652` +- The ``FormDataParser`` ``parse_functions`` attribute and ``get_parse_func`` method, + and the invalid ``application/x-url-encoded`` content type, are deprecated. + :pr:`2653` +- ``generate_password_hash`` supports scrypt. Plain hash methods are deprecated, only + scrypt and pbkdf2 are supported. :issue:`2654` + + +Version 2.2.3 +------------- + +Released 2023-02-14 + +- Ensure that URL rules using path converters will redirect with strict slashes when + the trailing slash is missing. :issue:`2533` +- Type signature for ``get_json`` specifies that return type is not optional when + ``silent=False``. :issue:`2508` +- ``parse_content_range_header`` returns ``None`` for a value like ``bytes */-1`` + where the length is invalid, instead of raising an ``AssertionError``. :issue:`2531` +- Address remaining ``ResourceWarning`` related to the socket used by ``run_simple``. + Remove ``prepare_socket``, which now happens when creating the server. :issue:`2421` +- Update pre-existing headers for ``multipart/form-data`` requests with the test + client. :issue:`2549` +- Fix handling of header extended parameters such that they are no longer quoted. + :issue:`2529` +- ``LimitedStream.read`` works correctly when wrapping a stream that may not return + the requested size in one ``read`` call. :issue:`2558` +- A cookie header that starts with ``=`` is treated as an empty key and discarded, + rather than stripping the leading ``==``. +- Specify a maximum number of multipart parts, default 1000, after which a + ``RequestEntityTooLarge`` exception is raised on parsing. This mitigates a DoS + attack where a larger number of form/file parts would result in disproportionate + resource use. + + + Version 2.2.2 ------------- @@ -23,6 +294,7 @@ Released 2022-08-08 ``run_simple``. 
:issue:`2421` + Version 2.2.1 ------------- @@ -54,8 +326,9 @@ Released 2022-07-23 debug console. :pr:`2439` - Fix compatibility with Python 3.11 by ensuring that ``end_lineno`` and ``end_col_offset`` are present on AST nodes. :issue:`2425` -- Add a new faster matching router based on a state - machine. :pr:`2433` +- Add a new faster URL matching router based on a state machine. If a custom converter + needs to match a ``/`` it must set the class variable ``part_isolating = False``. + :pr:`2433` - Fix branch leaf path masking branch paths when strict-slashes is disabled. :issue:`1074` - Names within options headers are always converted to lowercase. This @@ -775,7 +1048,7 @@ Released 2019-03-19 (:pr:`1358`) - :func:`http.parse_cookie` ignores empty segments rather than producing a cookie with no key or value. (:issue:`1245`, :pr:`1301`) -- :func:`~http.parse_authorization_header` (and +- ``http.parse_authorization_header`` (and :class:`~datastructures.Authorization`, :attr:`~wrappers.Request.authorization`) treats the authorization header as UTF-8. On Python 2, basic auth username and password are @@ -1540,8 +1813,8 @@ Version 0.9.2 (bugfix release, released on July 18th 2013) -- Added `unsafe` parameter to :func:`~werkzeug.urls.url_quote`. -- Fixed an issue with :func:`~werkzeug.urls.url_quote_plus` not quoting +- Added ``unsafe`` parameter to ``urls.url_quote``. +- Fixed an issue with ``urls.url_quote_plus`` not quoting `'+'` correctly. - Ported remaining parts of :class:`~werkzeug.contrib.RedisCache` to Python 3.3. @@ -1590,9 +1863,8 @@ Released on June 13nd 2013, codename Planierraupe. certificates easily and load them from files. - Refactored test client to invoke the open method on the class for redirects. This makes subclassing more powerful. -- :func:`werkzeug.wsgi.make_chunk_iter` and - :func:`werkzeug.wsgi.make_line_iter` now support processing of - iterators and streams. 
+- ``wsgi.make_chunk_iter`` and ``make_line_iter`` now support processing + of iterators and streams. - URL generation by the routing system now no longer quotes ``+``. - URL fixing now no longer quotes certain reserved characters. @@ -1690,7 +1962,7 @@ Version 0.8.3 (bugfix release, released on February 5th 2012) -- Fixed another issue with :func:`werkzeug.wsgi.make_line_iter` +- Fixed another issue with ``wsgi.make_line_iter`` where lines longer than the buffer size were not handled properly. - Restore stdout after debug console finished executing so @@ -1758,7 +2030,7 @@ Released on September 29th 2011, codename Lötkolben - Werkzeug now uses a new method to check that the length of incoming data is complete and will raise IO errors by itself if the server fails to do so. -- :func:`~werkzeug.wsgi.make_line_iter` now requires a limit that is +- ``wsgi.make_line_iter`` now requires a limit that is not higher than the length the stream can provide. - Refactored form parsing into a form parser class that makes it possible to hook into individual parts of the parsing process for debugging and @@ -1958,7 +2230,7 @@ Released on Feb 19th 2010, codename Hammer. - the form data parser will now look at the filename instead the content type to figure out if it should treat the upload as regular form data or file upload. This fixes a bug with Google Chrome. -- improved performance of `make_line_iter` and the multipart parser +- improved performance of ``make_line_iter`` and the multipart parser for binary uploads. - fixed :attr:`~werkzeug.BaseResponse.is_streamed` - fixed a path quoting bug in `EnvironBuilder` that caused PATH_INFO and @@ -2087,7 +2359,7 @@ Released on April 24th, codename Schlagbohrer. - added :mod:`werkzeug.contrib.lint` - added `passthrough_errors` to `run_simple`. 
- added `secure_filename` -- added :func:`make_line_iter` +- added ``make_line_iter`` - :class:`MultiDict` copies now instead of revealing internal lists to the caller for `getlist` and iteration functions that return lists. diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 9f40800..97486de 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -7,19 +7,17 @@ Thank you for considering contributing to Werkzeug! Support questions ----------------- -Please don't use the issue tracker for this. The issue tracker is a -tool to address bugs and feature requests in Werkzeug itself. Use one of -the following resources for questions about using Werkzeug or issues -with your own code: +Please don't use the issue tracker for this. The issue tracker is a tool to address bugs +and feature requests in Werkzeug itself. Use one of the following resources for +questions about using Werkzeug or issues with your own code: -- The ``#get-help`` channel on our Discord chat: - https://discord.gg/pallets -- The mailing list flask@python.org for long term discussion or larger - issues. +- The ``#questions`` channel on our Discord chat: https://discord.gg/pallets - Ask on `Stack Overflow`_. Search with Google first using: ``site:stackoverflow.com werkzeug {search term, exception message, etc.}`` +- Ask on our `GitHub Discussions`_ for long term discussion or larger questions. .. _Stack Overflow: https://stackoverflow.com/questions/tagged/werkzeug?tab=Frequent +.. _GitHub Discussions: https://github.com/pallets/werkzeug/discussions Reporting issues @@ -66,9 +64,30 @@ Include the following in your patch: .. _pre-commit: https://pre-commit.com -First time setup -~~~~~~~~~~~~~~~~ +First time setup using GitHub Codespaces +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +`GitHub Codespaces`_ creates a development environment that is already set up for the +project. 
By default it opens in Visual Studio Code for the Web, but this can +be changed in your GitHub profile settings to use Visual Studio Code or JetBrains +PyCharm on your local computer. + +- Make sure you have a `GitHub account`_. +- From the project's repository page, click the green "Code" button and then "Create + codespace on main". +- The codespace will be set up, then Visual Studio Code will open. However, you'll + need to wait a bit longer for the Python extension to be installed. You'll know it's + ready when the terminal at the bottom shows that the virtualenv was activated. +- Check out a branch and `start coding`_. + +.. _GitHub Codespaces: https://docs.github.com/en/codespaces +.. _devcontainer: https://docs.github.com/en/codespaces/setting-up-your-project-for-codespaces/adding-a-dev-container-configuration/introduction-to-dev-containers + + +First time setup in your local environment +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Make sure you have a `GitHub account`_. - Download and install the `latest version of git`_. - Configure git with your `username`_ and `email`_. @@ -77,99 +96,93 @@ First time setup $ git config --global user.name 'your name' $ git config --global user.email 'your email' -- Make sure you have a `GitHub account`_. - Fork Werkzeug to your GitHub account by clicking the `Fork`_ button. -- `Clone`_ the main repository locally. +- `Clone`_ your fork locally, replacing ``your-username`` in the command below with + your actual username. .. code-block:: text - $ git clone https://github.com/pallets/werkzeug + $ git clone https://github.com/your-username/werkzeug $ cd werkzeug -- Add your fork as a remote to push your work to. Replace - ``{username}`` with your username. This names the remote "fork", the - default Pallets remote is "origin". - - .. code-block:: text - - $ git remote add fork https://github.com/{username}/werkzeug - -- Create a virtualenv. - - .. code-block:: text - - $ python3 -m venv env - $ . 
env/bin/activate - - On Windows, activating is different. - - .. code-block:: text - - > env\Scripts\activate - -- Upgrade pip and setuptools. - - .. code-block:: text - - $ python -m pip install --upgrade pip setuptools - -- Install the development dependencies, then install Werkzeug in - editable mode. +- Create a virtualenv. Use the latest version of Python. + + - Linux/macOS + + .. code-block:: text + + $ python3 -m venv .venv + $ . .venv/bin/activate + + - Windows + + .. code-block:: text + + > py -3 -m venv .venv + > .venv\Scripts\activate + +- Install the development dependencies, then install Werkzeug in editable mode. .. code-block:: text + $ python -m pip install -U pip $ pip install -r requirements/dev.txt && pip install -e . - Install the pre-commit hooks. .. code-block:: text - $ pre-commit install + $ pre-commit install --install-hooks +.. _GitHub account: https://github.com/join .. _latest version of git: https://git-scm.com/downloads .. _username: https://docs.github.com/en/github/using-git/setting-your-username-in-git .. _email: https://docs.github.com/en/github/setting-up-and-managing-your-github-user-account/setting-your-commit-email-address -.. _GitHub account: https://github.com/join .. _Fork: https://github.com/pallets/werkzeug/fork .. _Clone: https://docs.github.com/en/github/getting-started-with-github/fork-a-repo#step-2-create-a-local-clone-of-your-fork +.. _start coding: + Start coding ~~~~~~~~~~~~ -- Create a branch to identify the issue you would like to work on. If - you're submitting a bug or documentation fix, branch off of the - latest ".x" branch. +- Create a branch to identify the issue you would like to work on. If you're + submitting a bug or documentation fix, branch off of the latest ".x" branch. .. code-block:: text $ git fetch origin - $ git checkout -b your-branch-name origin/2.0.x + $ git checkout -b your-branch-name origin/2.2.x - If you're submitting a feature addition or change, branch off of the - "main" branch. 
+ If you're submitting a feature addition or change, branch off of the "main" branch. .. code-block:: text $ git fetch origin $ git checkout -b your-branch-name origin/main -- Using your favorite editor, make your changes, - `committing as you go`_. -- Include tests that cover any code changes you make. Make sure the - test fails without your patch. Run the tests as described below. -- Push your commits to your fork on GitHub and - `create a pull request`_. Link to the issue being addressed with - ``fixes #123`` in the pull request. +- Using your favorite editor, make your changes, `committing as you go`_. + + - If you are in a codespace, you will be prompted to `create a fork`_ the first + time you make a commit. Enter ``Y`` to continue. + +- Include tests that cover any code changes you make. Make sure the test fails without + your patch. Run the tests as described below. +- Push your commits to your fork on GitHub and `create a pull request`_. Link to the + issue being addressed with ``fixes #123`` in the pull request description. .. code-block:: text - $ git push --set-upstream fork your-branch-name + $ git push --set-upstream origin your-branch-name -.. _committing as you go: https://dont-be-afraid-to-commit.readthedocs.io/en/latest/git/commandlinegit.html#commit-your-changes +.. _committing as you go: https://afraid-to-commit.readthedocs.io/en/latest/git/commandlinegit.html#commit-your-changes +.. _create a fork: https://docs.github.com/en/codespaces/developing-in-codespaces/using-source-control-in-your-codespace#about-automatic-forking .. _create a pull request: https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request +.. 
_Running the tests: + Running the tests ~~~~~~~~~~~~~~~~~ diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 8942481..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,12 +0,0 @@ -include CHANGES.rst -include tox.ini -include requirements/*.txt -graft artwork -graft docs -prune docs/_build -graft examples -graft tests -include src/werkzeug/py.typed -include src/werkzeug/*.pyi -graft src/werkzeug/debug/shared -global-exclude *.pyc diff --git a/README.rst b/README.rst index f1592a5..220c997 100644 --- a/README.rst +++ b/README.rst @@ -86,6 +86,4 @@ Links - PyPI Releases: https://pypi.org/project/Werkzeug/ - Source Code: https://github.com/pallets/werkzeug/ - Issue Tracker: https://github.com/pallets/werkzeug/issues/ -- Website: https://palletsprojects.com/p/werkzeug/ -- Twitter: https://twitter.com/PalletsTeam - Chat: https://discord.gg/pallets diff --git a/artwork/logo.png b/artwork/logo.png deleted file mode 100644 index 61666ab..0000000 Binary files a/artwork/logo.png and /dev/null differ diff --git a/artwork/logo.svg b/artwork/logo.svg deleted file mode 100644 index bd65219..0000000 --- a/artwork/logo.svg +++ /dev/null @@ -1,88 +0,0 @@ - - - - - - - - - image/svg+xml - - - - - - - - - - - diff --git a/debian/changelog b/debian/changelog deleted file mode 100644 index e70808f..0000000 --- a/debian/changelog +++ /dev/null @@ -1,29 +0,0 @@ -python-werkzeug (2.2.2-ok5) yangtze; urgency=medium - - * Update version info. - - -- sufang Tue, 21 Mar 2023 14:06:59 +0800 - -python-werkzeug (2.2.2-ok4) yangtze; urgency=medium - - * Add python3-jinja2 to build-depends. - - -- sufang Tue, 14 Mar 2023 16:11:45 +0800 - -python-werkzeug (2.2.2-ok3) yangtze; urgency=medium - - * Fix command 'install' has no such option 'install_layout'. - - -- sufang Tue, 14 Mar 2023 14:53:39 +0800 - -python-werkzeug (2.2.2-ok2) yangtze; urgency=medium - - * Apply patch. 
- - -- sufang Tue, 14 Mar 2023 14:51:23 +0800 - -python-werkzeug (2.2.2-ok1) yangtze; urgency=medium - - * Build for openkylin. - - -- sufang Mon, 30 Jan 2023 17:20:54 +0800 diff --git a/debian/control b/debian/control deleted file mode 100644 index 6b3519c..0000000 --- a/debian/control +++ /dev/null @@ -1,69 +0,0 @@ -Source: python-werkzeug -Section: python -Priority: optional -Maintainer: OpenKylin Developers -Standards-Version: 4.6.1 -Build-Depends: - debhelper-compat (= 13), - dh-python, - python3-all, - python3-cryptography , - python3-doc, - python3-ephemeral-port-reserve , - python3-greenlet , - python3-pallets-sphinx-themes , - python3-pytest , - python3-pytest-timeout , - python3-pytest-xprocess , - python3-setuptools, - python3-sphinx , - python3-sphinx-issues , - python3-sphinxcontrib-log-cabinet , - python3-watchdog , - python3-jinja2 -Homepage: http://werkzeug.pocoo.org/ -Vcs-Git: https://gitee.com/openkylin/python-werkzeug.git -Vcs-Browser: https://gitee.com/openkylin/python-werkzeug -Testsuite: autopkgtest-pkg-python -Rules-Requires-Root: no - -Package: python3-werkzeug -Architecture: all -Depends: - libjs-jquery, - ${misc:Depends}, - ${python3:Depends}, -Recommends: - python3-openssl, - python3-pyinotify, -Suggests: - ipython3, - python-werkzeug-doc, - python3-lxml, - python3-pkg-resources, - python3-watchdog, -Description: collection of utilities for WSGI applications (Python 3.x) - The Web Server Gateway Interface (WSGI) is a standard interface between web - server software and web applications written in Python. - . - Werkzeug is a lightweight library for interfacing with WSGI. It features - request and response objects, an interactive debugging system and a powerful - URI dispatcher. Combine with your choice of third party libraries and - middleware to easily create a custom application framework. - . - This package contains the Python 3.x module. 
- -Package: python-werkzeug-doc -Section: doc -Architecture: all -Depends: - ${misc:Depends}, - ${sphinxdoc:Depends}, -Multi-Arch: foreign -Description: documentation for the werkzeug Python library (docs) - Werkzeug is a lightweight library for interfacing with WSGI. It features - request and response objects, an interactive debugging system and a powerful - URI dispatcher. Combine with your choice of third party libraries and - middleware to easily create a custom application framework. - . - This package provides the Sphinx generated documentation for Werkzeug. diff --git a/debian/copyright b/debian/copyright deleted file mode 100644 index 6989e63..0000000 --- a/debian/copyright +++ /dev/null @@ -1,369 +0,0 @@ -Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ -Upstream-Name: python-werkzeug -Source: -# -# Please double check copyright with the licensecheck(1) command. - -Files: .editorconfig - .gitattributes - .github/ISSUE_TEMPLATE/bug-report.md - .github/ISSUE_TEMPLATE/config.yml - .github/ISSUE_TEMPLATE/feature-request.md - .github/dependabot.yml - .github/workflows/lock.yaml - .github/workflows/tests.yaml - .gitignore - .pre-commit-config.yaml - .readthedocs.yaml - CHANGES.rst - CODE_OF_CONDUCT.md - CONTRIBUTING.rst - MANIFEST.in - README.rst - artwork/logo.png - docs/Makefile - docs/_static/debug-screenshot.png - docs/_static/favicon.ico - docs/_static/shortly.png - docs/_static/werkzeug.png - docs/changes.rst - docs/conf.py - docs/datastructures.rst - docs/debug.rst - docs/deployment/apache-httpd.rst - docs/deployment/eventlet.rst - docs/deployment/gevent.rst - docs/deployment/gunicorn.rst - docs/deployment/index.rst - docs/deployment/mod_wsgi.rst - docs/deployment/nginx.rst - docs/deployment/proxy_fix.rst - docs/deployment/uwsgi.rst - docs/deployment/waitress.rst - docs/exceptions.rst - docs/http.rst - docs/index.rst - docs/installation.rst - docs/levels.rst - docs/license.rst - docs/local.rst - docs/make.bat - 
docs/middleware/dispatcher.rst - docs/middleware/http_proxy.rst - docs/middleware/index.rst - docs/middleware/lint.rst - docs/middleware/profiler.rst - docs/middleware/proxy_fix.rst - docs/middleware/shared_data.rst - docs/quickstart.rst - docs/request_data.rst - docs/routing.rst - docs/serving.rst - docs/terms.rst - docs/test.rst - docs/tutorial.rst - docs/unicode.rst - docs/urls.rst - docs/utils.rst - docs/wrappers.rst - docs/wsgi.rst - examples/README.rst - examples/coolmagic/__init__.py - examples/coolmagic/application.py - examples/coolmagic/helpers.py - examples/coolmagic/public/style.css - examples/coolmagic/templates/static/about.html - examples/coolmagic/templates/static/index.html - examples/coolmagic/templates/static/not_found.html - examples/coolmagic/utils.py - examples/coolmagic/views/__init__.py - examples/coolmagic/views/static.py - examples/couchy/README - examples/couchy/__init__.py - examples/couchy/application.py - examples/couchy/models.py - examples/couchy/static/style.css - examples/couchy/templates/display.html - examples/couchy/templates/list.html - examples/couchy/templates/new.html - examples/couchy/templates/not_found.html - examples/couchy/utils.py - examples/couchy/views.py - examples/cupoftee/__init__.py - examples/cupoftee/application.py - examples/cupoftee/db.py - examples/cupoftee/network.py - examples/cupoftee/pages.py - examples/cupoftee/shared/content.png - examples/cupoftee/shared/down.png - examples/cupoftee/shared/favicon.ico - examples/cupoftee/shared/header.png - examples/cupoftee/shared/logo.png - examples/cupoftee/shared/style.css - examples/cupoftee/shared/up.png - examples/cupoftee/templates/missingpage.html - examples/cupoftee/templates/search.html - examples/cupoftee/templates/server.html - examples/cupoftee/templates/serverlist.html - examples/cupoftee/utils.py - examples/httpbasicauth.py - examples/i18nurls/__init__.py - examples/i18nurls/application.py - examples/i18nurls/templates/about.html - 
examples/i18nurls/templates/blog.html - examples/i18nurls/templates/index.html - examples/i18nurls/urls.py - examples/i18nurls/views.py - examples/manage-coolmagic.py - examples/manage-couchy.py - examples/manage-cupoftee.py - examples/manage-i18nurls.py - examples/manage-plnt.py - examples/manage-shorty.py - examples/manage-simplewiki.py - examples/manage-webpylike.py - examples/partial/README - examples/partial/complex_routing.py - examples/plnt/__init__.py - examples/plnt/database.py - examples/plnt/shared/style.css - examples/plnt/sync.py - examples/plnt/templates/about.html - examples/plnt/templates/index.html - examples/plnt/utils.py - examples/plnt/views.py - examples/plnt/webapp.py - examples/shortly/shortly.py - examples/shortly/static/style.css - examples/shortly/templates/404.html - examples/shortly/templates/new_url.html - examples/shortly/templates/short_link_details.html - examples/shorty/__init__.py - examples/shorty/application.py - examples/shorty/models.py - examples/shorty/static/style.css - examples/shorty/templates/display.html - examples/shorty/templates/list.html - examples/shorty/templates/new.html - examples/shorty/templates/not_found.html - examples/shorty/utils.py - examples/shorty/views.py - examples/simplewiki/__init__.py - examples/simplewiki/actions.py - examples/simplewiki/application.py - examples/simplewiki/database.py - examples/simplewiki/shared/style.css - examples/simplewiki/specialpages.py - examples/simplewiki/utils.py - examples/upload.py - examples/webpylike/example.py - examples/webpylike/webpylike.py - examples/wsecho.py - requirements/dev.in - requirements/dev.txt - requirements/docs.in - requirements/docs.txt - requirements/tests.in - requirements/tests.txt - requirements/typing.in - requirements/typing.txt - setup.cfg - setup.py - src/werkzeug/__init__.py - src/werkzeug/_internal.py - src/werkzeug/_reloader.py - src/werkzeug/datastructures.py - src/werkzeug/datastructures.pyi - src/werkzeug/debug/__init__.py - 
src/werkzeug/debug/console.py - src/werkzeug/debug/repr.py - src/werkzeug/debug/shared/ICON_LICENSE.md - src/werkzeug/debug/shared/console.png - src/werkzeug/debug/shared/debugger.js - src/werkzeug/debug/shared/less.png - src/werkzeug/debug/shared/more.png - src/werkzeug/debug/shared/style.css - src/werkzeug/debug/tbtools.py - src/werkzeug/exceptions.py - src/werkzeug/formparser.py - src/werkzeug/http.py - src/werkzeug/local.py - src/werkzeug/middleware/__init__.py - src/werkzeug/middleware/dispatcher.py - src/werkzeug/middleware/http_proxy.py - src/werkzeug/middleware/lint.py - src/werkzeug/middleware/profiler.py - src/werkzeug/middleware/proxy_fix.py - src/werkzeug/middleware/shared_data.py - src/werkzeug/py.typed - src/werkzeug/routing/__init__.py - src/werkzeug/routing/converters.py - src/werkzeug/routing/exceptions.py - src/werkzeug/routing/map.py - src/werkzeug/routing/matcher.py - src/werkzeug/routing/rules.py - src/werkzeug/sansio/__init__.py - src/werkzeug/sansio/http.py - src/werkzeug/sansio/multipart.py - src/werkzeug/sansio/request.py - src/werkzeug/sansio/response.py - src/werkzeug/sansio/utils.py - src/werkzeug/security.py - src/werkzeug/serving.py - src/werkzeug/test.py - src/werkzeug/testapp.py - src/werkzeug/urls.py - src/werkzeug/user_agent.py - src/werkzeug/utils.py - src/werkzeug/wrappers/__init__.py - src/werkzeug/wrappers/request.py - src/werkzeug/wrappers/response.py - src/werkzeug/wsgi.py - tests/conftest.py - tests/live_apps/data_app.py - tests/live_apps/reloader_app.py - tests/live_apps/run.py - tests/live_apps/standard_app.py - tests/live_apps/streaming_app.py - tests/middleware/test_dispatcher.py - tests/middleware/test_http_proxy.py - tests/middleware/test_lint.py - tests/middleware/test_proxy_fix.py - tests/middleware/test_shared_data.py - tests/multipart/firefox3-2png1txt/file1.png - tests/multipart/firefox3-2png1txt/file2.png - tests/multipart/firefox3-2png1txt/request.http - tests/multipart/firefox3-2png1txt/text.txt - 
tests/multipart/firefox3-2pnglongtext/file1.png - tests/multipart/firefox3-2pnglongtext/file2.png - tests/multipart/firefox3-2pnglongtext/request.http - tests/multipart/firefox3-2pnglongtext/text.txt - tests/multipart/ie6-2png1txt/file1.png - tests/multipart/ie6-2png1txt/file2.png - tests/multipart/ie6-2png1txt/request.http - tests/multipart/ie6-2png1txt/text.txt - tests/multipart/ie7_full_path_request.http - tests/multipart/opera8-2png1txt/file1.png - tests/multipart/opera8-2png1txt/file2.png - tests/multipart/opera8-2png1txt/request.http - tests/multipart/opera8-2png1txt/text.txt - tests/multipart/webkit3-2png1txt/file1.png - tests/multipart/webkit3-2png1txt/file2.png - tests/multipart/webkit3-2png1txt/request.http - tests/multipart/webkit3-2png1txt/text.txt - tests/res/test.txt - tests/sansio/__init__.py - tests/sansio/test_multipart.py - tests/sansio/test_request.py - tests/sansio/test_utils.py - tests/test_datastructures.py - tests/test_debug.py - tests/test_exceptions.py - tests/test_formparser.py - tests/test_http.py - tests/test_internal.py - tests/test_local.py - tests/test_routing.py - tests/test_security.py - tests/test_send_file.py - tests/test_serving.py - tests/test_test.py - tests/test_urls.py - tests/test_utils.py - tests/test_wrappers.py - tests/test_wsgi.py - tox.ini -Copyright: __NO_COPYRIGHT_NOR_LICENSE__ -License: __NO_COPYRIGHT_NOR_LICENSE__ - -Files: LICENSE.rst -Copyright: 2007 Pallets -License: BSD-3-Clause - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - . - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - . - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - . - 3. 
Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - . - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A - PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED - TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - . - On Debian systems, the complete text of the BSD 3-clause "New" or "Revised" - License can be found in `/usr/share/common-licenses/BSD'. 
- -#---------------------------------------------------------------------------- -# xml and html files (skipped): -# tests/res/index.html -# examples/shortly/templates/layout.html -# examples/coolmagic/templates/layout.html -# examples/plnt/templates/layout.html -# examples/simplewiki/templates/action_revert.html -# examples/simplewiki/templates/page_index.html -# examples/simplewiki/templates/page_missing.html -# examples/simplewiki/templates/action_edit.html -# examples/simplewiki/templates/action_log.html -# examples/simplewiki/templates/action_show.html -# examples/simplewiki/templates/macros.xml -# examples/simplewiki/templates/recent_changes.html -# examples/simplewiki/templates/missing_action.html -# examples/simplewiki/templates/layout.html -# examples/simplewiki/templates/action_diff.html -# examples/i18nurls/templates/layout.html -# examples/shorty/templates/layout.html -# examples/cupoftee/templates/layout.html -# examples/couchy/templates/layout.html -# .github/pull_request_template.md -# artwork/logo.svg - -#---------------------------------------------------------------------------- -# Files marked as NO_LICENSE_TEXT_FOUND may be covered by the following -# license/copyright files. - -#---------------------------------------------------------------------------- -# License file: LICENSE.rst - Copyright 2007 Pallets - . - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are - met: - . - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - . - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - . - 3. 
Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - . - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A - PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED - TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR - PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF - LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING - NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/debian/patches/series b/debian/patches/series deleted file mode 100644 index 4a97dfa..0000000 --- a/debian/patches/series +++ /dev/null @@ -1 +0,0 @@ -# You must remove unused comment lines for the released package. diff --git a/debian/python-werkzeug-doc.doc-base b/debian/python-werkzeug-doc.doc-base deleted file mode 100644 index 2e5fb4a..0000000 --- a/debian/python-werkzeug-doc.doc-base +++ /dev/null @@ -1,10 +0,0 @@ -Document: werkzeug -Title: Werkzeug Documentation -Author: Armin Ronacher -Abstract: This document describes Werkzeug - collection of utilities for WSGI - applications written in Python. 
-Section: Programming/Python - -Format: HTML -Index: /usr/share/doc/python-werkzeug-doc/html/index.html -Files: /usr/share/doc/python-werkzeug-doc/html/*.html diff --git a/debian/python-werkzeug-doc.examples b/debian/python-werkzeug-doc.examples deleted file mode 100644 index e39721e..0000000 --- a/debian/python-werkzeug-doc.examples +++ /dev/null @@ -1 +0,0 @@ -examples/* diff --git a/debian/python-werkzeug-doc.links b/debian/python-werkzeug-doc.links deleted file mode 100644 index 2c69103..0000000 --- a/debian/python-werkzeug-doc.links +++ /dev/null @@ -1,6 +0,0 @@ -/usr/share/doc/python-werkzeug-doc/examples /usr/share/doc/python-werkzeug/examples -/usr/share/doc/python-werkzeug-doc/examples /usr/share/doc/python3-werkzeug/examples -/usr/share/doc/python-werkzeug-doc/html /usr/share/doc/python-werkzeug/html -/usr/share/doc/python-werkzeug-doc/html /usr/share/doc/python3-werkzeug/html -/usr/share/doc/python-werkzeug-doc/html/_sources /usr/share/doc/python-werkzeug/rst -/usr/share/doc/python-werkzeug-doc/html/_sources /usr/share/doc/python3-werkzeug/rst diff --git a/debian/python3-werkzeug.links b/debian/python3-werkzeug.links deleted file mode 100644 index 01ca329..0000000 --- a/debian/python3-werkzeug.links +++ /dev/null @@ -1 +0,0 @@ -/usr/share/javascript/jquery/jquery.js /usr/lib/python3/dist-packages/werkzeug/debug/shared/jquery.js diff --git a/debian/rules b/debian/rules deleted file mode 100755 index a3508a5..0000000 --- a/debian/rules +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/make -f - -# Copyright 2009, Noah Slater - -# Copying and distribution of this file, with or without modification, are -# permitted in any medium without royalty provided the copyright notice and this -# notice are preserved. 
- -export PYBUILD_NAME=werkzeug -export PYBUILD_TEST_PYTEST=1 -export SETUPTOOLS_USE_DISTUTILS=stdlib - -%: - dh $@ --with python3,sphinxdoc --buildsystem pybuild - -override_dh_auto_clean: - make -C docs clean - rm -rf build Werkzeug.egg-info/ - #find $(CURDIR) \( -name '\._*' -o -name '\.DS_Store' \) -delete - find . -iname '__pycache__' -exec rm -rf {} \; || true - rm -rf .pytest_cache - dh_auto_clean - -override_dh_fixperms: - find debian/ -name '*\.png' -exec chmod -x '{}' \; - dh_fixperms - -override_dh_installdocs: - dh_installdocs --doc-main-package=python-werkzeug-doc -ppython-werkzeug-doc - dh_installdocs - -override_dh_installexamples: - dh_installexamples --doc-main-package=python-werkzeug-doc -ppython-werkzeug-doc - -override_dh_sphinxdoc: -ifeq (,$(findstring nodocs, $(DEB_BUILD_OPTIONS))) - PYTHONPATH=src python3 -m sphinx -b html docs/ debian/python-werkzeug-doc/usr/share/doc/python-werkzeug-doc/html/ - dh_sphinxdoc -endif diff --git a/debian/source/format b/debian/source/format deleted file mode 100644 index 89ae9db..0000000 --- a/debian/source/format +++ /dev/null @@ -1 +0,0 @@ -3.0 (native) diff --git a/debian/tests/control b/debian/tests/control deleted file mode 100644 index bfa14fd..0000000 --- a/debian/tests/control +++ /dev/null @@ -1,5 +0,0 @@ -Tests: upstream -Depends: - @, - @builddeps@, -Restrictions: allow-stderr diff --git a/debian/tests/upstream b/debian/tests/upstream deleted file mode 100755 index ad3ff8b..0000000 --- a/debian/tests/upstream +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/sh -set -eu - -export LC_ALL=C.UTF-8 -pyvers=$(py3versions -r 2>/dev/null) - -cp -a tests "$AUTOPKGTEST_TMP" -cd "$AUTOPKGTEST_TMP" - -for py in ${pyvers}; do - echo "-=-=-=-=-=-=-=- running tests for ${py} -=-=-=-=-=-=-=-=-" - printf '$ %s\n' "${py} -m pytest tests" - ${py} -m pytest tests -done diff --git a/debian/upstream/metadata b/debian/upstream/metadata deleted file mode 100644 index 2c2d29c..0000000 --- a/debian/upstream/metadata +++ /dev/null @@ 
-1,4 +0,0 @@ -Bug-Database: https://github.com/pallets/werkzeug/issues -Bug-Submit: https://github.com/pallets/werkzeug/issues/new -Repository: https://github.com/pallets/werkzeug.git -Repository-Browse: https://github.com/pallets/werkzeug diff --git a/debian/watch b/debian/watch deleted file mode 100644 index 8b2c992..0000000 --- a/debian/watch +++ /dev/null @@ -1,6 +0,0 @@ -version=3 -opts=uversionmangle=s/(rc|a|b|c)/~$1/,\ -dversionmangle=auto,\ -repack,\ -filenamemangle=s/.+\/v?(\d\S*)\.tar\.gz/werkzeug-$1\.tar\.gz/ \ -https://github.com/pallets/werkzeug/tags .*/v?(\d\S*)\.tar\.gz diff --git a/docs/_static/favicon.ico b/docs/_static/favicon.ico deleted file mode 100644 index a3b079a..0000000 Binary files a/docs/_static/favicon.ico and /dev/null differ diff --git a/docs/_static/shortcut-icon.png b/docs/_static/shortcut-icon.png new file mode 100644 index 0000000..37cf028 Binary files /dev/null and b/docs/_static/shortcut-icon.png differ diff --git a/docs/_static/werkzeug-horizontal.png b/docs/_static/werkzeug-horizontal.png new file mode 100644 index 0000000..0581470 Binary files /dev/null and b/docs/_static/werkzeug-horizontal.png differ diff --git a/docs/_static/werkzeug-vertical.png b/docs/_static/werkzeug-vertical.png new file mode 100644 index 0000000..be2a7a3 Binary files /dev/null and b/docs/_static/werkzeug-vertical.png differ diff --git a/docs/_static/werkzeug.png b/docs/_static/werkzeug.png deleted file mode 100644 index 9cedb06..0000000 Binary files a/docs/_static/werkzeug.png and /dev/null differ diff --git a/docs/conf.py b/docs/conf.py index 96e998b..e09ef8f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,14 +26,13 @@ issues_github_path = "pallets/werkzeug" # HTML ----------------------------------------------------------------- html_theme = "werkzeug" +html_theme_options = {"index_sidebar_logo": False} html_context = { "project_links": [ ProjectLink("Donate", "https://palletsprojects.com/donate"), ProjectLink("PyPI Releases", 
"https://pypi.org/project/Werkzeug/"), ProjectLink("Source Code", "https://github.com/pallets/werkzeug/"), ProjectLink("Issue Tracker", "https://github.com/pallets/werkzeug/issues/"), - ProjectLink("Website", "https://palletsprojects.com/p/werkzeug/"), - ProjectLink("Twitter", "https://twitter.com/PalletsTeam"), ProjectLink("Chat", "https://discord.gg/pallets"), ] } @@ -43,8 +42,8 @@ html_sidebars = { } singlehtml_sidebars = {"index": ["project.html", "localtoc.html", "ethicalads.html"]} html_static_path = ["_static"] -html_favicon = "_static/favicon.ico" -html_logo = "_static/werkzeug.png" +html_favicon = "_static/shortcut-icon.png" +html_logo = "_static/werkzeug-vertical.png" html_title = f"Werkzeug Documentation ({version})" html_show_sourcelink = False diff --git a/docs/http.rst b/docs/http.rst index cbf4e04..790de31 100644 --- a/docs/http.rst +++ b/docs/http.rst @@ -53,10 +53,6 @@ by :rfc:`2616`, Werkzeug implements some custom data structures that are .. autofunction:: parse_cache_control_header -.. autofunction:: parse_authorization_header - -.. autofunction:: parse_www_authenticate_header - .. autofunction:: parse_if_range_header .. autofunction:: parse_range_header diff --git a/docs/index.rst b/docs/index.rst index c4f0019..4bc4e30 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,6 +1,12 @@ +.. rst-class:: hide-header + Werkzeug ======== +.. image:: _static/werkzeug-horizontal.png + :align: center + :target: https://werkzeug.palletsprojects.com + *werkzeug* German noun: "tool". Etymology: *werk* ("work"), *zeug* ("stuff") @@ -72,7 +78,6 @@ Additional Information :maxdepth: 2 terms - unicode request_data license changes diff --git a/docs/installation.rst b/docs/installation.rst index 9c5aa7f..7138f08 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -6,13 +6,7 @@ Python Version -------------- We recommend using the latest version of Python. Werkzeug supports -Python 3.7 and newer. 
- - -Dependencies ------------- - -Werkzeug does not have any direct dependencies. +Python 3.8 and newer. Optional dependencies diff --git a/docs/middleware/index.rst b/docs/middleware/index.rst index 70cddee..3d7ede4 100644 --- a/docs/middleware/index.rst +++ b/docs/middleware/index.rst @@ -1 +1,20 @@ -.. automodule:: werkzeug.middleware +Middleware +========== + +A WSGI middleware is a WSGI application that wraps another application +in order to observe or change its behavior. Werkzeug provides some +middleware for common use cases. + +.. toctree:: + :maxdepth: 1 + + proxy_fix + shared_data + dispatcher + http_proxy + lint + profiler + +The :doc:`interactive debugger ` is also a middleware that can +be applied manually, although it is typically used automatically with +the :doc:`development server `. diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 1568892..0f3714e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -43,9 +43,7 @@ there: >>> request = Request(environ) Now you can access the important variables and Werkzeug will parse them -for you and decode them where it makes sense. The default charset for -requests is set to `utf-8` but you can change that by subclassing -:class:`Request`. +for you and decode them where it makes sense. >>> request.path '/foo' diff --git a/docs/request_data.rst b/docs/request_data.rst index 83c6278..b1c97b2 100644 --- a/docs/request_data.rst +++ b/docs/request_data.rst @@ -73,23 +73,31 @@ read the stream *or* call :meth:`~Request.get_data`. Limiting Request Data --------------------- -To avoid being the victim of a DDOS attack you can set the maximum -accepted content length and request field sizes. The :class:`Request` -class has two attributes for that: :attr:`~Request.max_content_length` -and :attr:`~Request.max_form_memory_size`. +The :class:`Request` class provides a few attributes to control how much data is +processed from the request body. 
This can help mitigate DoS attacks that craft the +request in such a way that the server uses too many resources to handle it. Each of +these limits will raise a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` if they are +exceeded. -The first one can be used to limit the total content length. For example -by setting it to ``1024 * 1024 * 16`` the request won't accept more than -16MB of transmitted data. +- :attr:`~Request.max_content_length` Stop reading request data after this number + of bytes. It's better to configure this in the WSGI server or HTTP server, rather + than the WSGI application. +- :attr:`~Request.max_form_memory_size` Stop reading request data if any form part is + larger than this number of bytes. While file parts can be moved to disk, regular + form field data is stored in memory only. +- :attr:`~Request.max_form_parts` Stop reading request data if more than this number + of parts are sent in multipart form data. This is useful to stop a very large number + of very small parts, especially file parts. The default is 1000. -Because certain data can't be moved to the hard disk (regular post data) -whereas temporary files can, there is a second limit you can set. The -:attr:`~Request.max_form_memory_size` limits the size of `POST` -transmitted form data. By setting it to ``1024 * 1024 * 2`` you can make -sure that all in memory-stored fields are not more than 2MB in size. +Using Werkzeug to set these limits is only one layer of protection. WSGI servers +and HTTPS servers should set their own limits on size and timeouts. The operating system +or container manager should set limits on memory and processing time for server +processes. -This however does *not* affect in-memory stored files if the -`stream_factory` used returns a in-memory file. +If a 413 Content Too Large error is returned before the entire request is read, clients +may show a "connection reset" failure instead of the 413 error. 
This is based on how the +WSGI/HTTP server and client handle connections, it's not something the WSGI application +(Werkzeug) has control over. How to extend Parsing? diff --git a/docs/test.rst b/docs/test.rst index efb449a..d31ac59 100644 --- a/docs/test.rst +++ b/docs/test.rst @@ -18,8 +18,8 @@ requests. >>> response = c.get("/") >>> response.status_code 200 ->>> resp.headers -Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '6658')]) +>>> response.headers +Headers([('Content-Type', 'text/html; charset=utf-8'), ('Content-Length', '5211')]) >>> response.get_data(as_text=True) '...' @@ -102,6 +102,10 @@ API :members: :member-order: bysource +.. autoclass:: Cookie + :members: + :member-order: bysource + .. autoclass:: EnvironBuilder :members: :member-order: bysource diff --git a/docs/unicode.rst b/docs/unicode.rst deleted file mode 100644 index 30f76f5..0000000 --- a/docs/unicode.rst +++ /dev/null @@ -1,76 +0,0 @@ -Unicode -======= - -.. currentmodule:: werkzeug - -Werkzeug uses strings internally everwhere text data is assumed, even if -the HTTP standard is not Unicode aware. Basically all incoming data is -decoded from the charset (UTF-8 by default) so that you don't work with -bytes directly. Outgoing data is encoded into the target charset. - - -Unicode in Python ------------------ - -Imagine you have the German Umlaut ``ö``. In ASCII you cannot represent -that character, but in the ``latin-1`` and ``utf-8`` character sets you -can represent it, but they look different when encoded: - ->>> "ö".encode("latin1") -b'\xf6' ->>> "ö".encode("utf-8") -b'\xc3\xb6' - -An ``ö`` looks different depending on the encoding which makes it hard -to work with it as bytes. Instead, Python treats strings as Unicode text -and stores the information ``LATIN SMALL LETTER O WITH DIAERESIS`` -instead of the bytes for ``ö`` in a specific encoding. The length of a -string with 1 character will be 1, where the length of the bytes might -be some other value. 
- - -Unicode in HTTP ---------------- - -However, the HTTP spec was written in a time where ASCII bytes were the -common way data was represented. To work around this for the modern -web, Werkzeug decodes and encodes incoming and outgoing data -automatically. Data sent from the browser to the web application is -decoded from UTF-8 bytes into a string. Data sent from the application -back to the browser is encoded back to UTF-8. - - -Error Handling --------------- - -Functions that do internal encoding or decoding accept an ``errors`` -keyword argument that is passed to :meth:`str.decode` and -:meth:`str.encode`. The default is ``'replace'`` so that errors are easy -to spot. It might be useful to set it to ``'strict'`` in order to catch -the error and report the bad data to the client. - - -Request and Response Objects ----------------------------- - -In most cases, you should stick with Werkzeug's default encoding of -UTF-8. If you have a specific reason to, you can subclass -:class:`wrappers.Request` and :class:`wrappers.Response` to change the -encoding and error handling. - -.. code-block:: python - - from werkzeug.wrappers.request import Request - from werkzeug.wrappers.response import Response - - class Latin1Request(Request): - charset = "latin1" - encoding_errors = "strict" - - class Latin1Response(Response): - charset = "latin1" - -The error handling can only be changed for the request. Werkzeug will -always raise errors when encoding to bytes in the response. It's your -responsibility to not create data that is not present in the target -charset. This is not an issue for UTF-8. diff --git a/docs/utils.rst b/docs/utils.rst index 0d4e339..6afa4ab 100644 --- a/docs/utils.rst +++ b/docs/utils.rst @@ -23,6 +23,8 @@ General Helpers .. autofunction:: send_file +.. autofunction:: send_from_directory + .. autofunction:: import_string .. 
autofunction:: find_modules diff --git a/docs/wsgi.rst b/docs/wsgi.rst index a96916b..67b3bb6 100644 --- a/docs/wsgi.rst +++ b/docs/wsgi.rst @@ -22,10 +22,6 @@ iterator and the input stream. .. autoclass:: LimitedStream :members: -.. autofunction:: make_line_iter - -.. autofunction:: make_chunk_iter - .. autofunction:: wrap_file @@ -43,18 +39,6 @@ information or perform common manipulations: .. autofunction:: get_current_url -.. autofunction:: get_query_string - -.. autofunction:: get_script_name - -.. autofunction:: get_path_info - -.. autofunction:: pop_path_info - -.. autofunction:: peek_path_info - -.. autofunction:: extract_path_info - .. autofunction:: host_is_trusted diff --git a/examples/couchy/utils.py b/examples/couchy/utils.py index 03d1681..5c39fdf 100644 --- a/examples/couchy/utils.py +++ b/examples/couchy/utils.py @@ -1,6 +1,7 @@ from os import path from random import randrange from random import sample +from urllib.parse import urlsplit from jinja2 import Environment from jinja2 import FileSystemLoader @@ -8,7 +9,6 @@ from werkzeug.local import Local from werkzeug.local import LocalManager from werkzeug.routing import Map from werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response @@ -49,7 +49,7 @@ def render_template(template, **context): def validate_url(url): - return url_parse(url)[0] in ALLOWED_SCHEMES + return urlsplit(url)[0] in ALLOWED_SCHEMES def get_random_uid(): diff --git a/examples/shortly/shortly.py b/examples/shortly/shortly.py index 10e957e..5205f22 100644 --- a/examples/shortly/shortly.py +++ b/examples/shortly/shortly.py @@ -1,5 +1,6 @@ """A simple URL shortener using Werkzeug and redis.""" import os +from urllib.parse import urlsplit import redis from jinja2 import Environment @@ -9,7 +10,6 @@ from werkzeug.exceptions import NotFound from werkzeug.middleware.shared_data import SharedDataMiddleware from werkzeug.routing import Map from 
werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import redirect from werkzeug.wrappers import Request from werkzeug.wrappers import Response @@ -27,12 +27,12 @@ def base36_encode(number): def is_valid_url(url): - parts = url_parse(url) + parts = urlsplit(url) return parts.scheme in ("http", "https") def get_hostname(url): - return url_parse(url).netloc + return urlsplit(url).netloc class Shortly: diff --git a/examples/shorty/utils.py b/examples/shorty/utils.py index 2d9fe0e..4d064e3 100644 --- a/examples/shorty/utils.py +++ b/examples/shorty/utils.py @@ -1,6 +1,7 @@ from os import path from random import randrange from random import sample +from urllib.parse import urlsplit from jinja2 import Environment from jinja2 import FileSystemLoader @@ -11,7 +12,6 @@ from werkzeug.local import Local from werkzeug.local import LocalManager from werkzeug.routing import Map from werkzeug.routing import Rule -from werkzeug.urls import url_parse from werkzeug.utils import cached_property from werkzeug.wrappers import Response @@ -59,7 +59,7 @@ def render_template(template, **context): def validate_url(url): - return url_parse(url)[0] in ALLOWED_SCHEMES + return urlsplit(url)[0] in ALLOWED_SCHEMES def get_random_uid(): diff --git a/examples/simplewiki/utils.py b/examples/simplewiki/utils.py index 6cafab4..00729c6 100644 --- a/examples/simplewiki/utils.py +++ b/examples/simplewiki/utils.py @@ -1,12 +1,12 @@ from os import path +from urllib.parse import quote +from urllib.parse import urlencode import creoleparser from genshi import Stream from genshi.template import TemplateLoader from werkzeug.local import Local from werkzeug.local import LocalManager -from werkzeug.urls import url_encode -from werkzeug.urls import url_quote from werkzeug.utils import cached_property from werkzeug.wrappers import Request as BaseRequest from werkzeug.wrappers import Response as BaseResponse @@ -58,9 +58,9 @@ def href(*args, **kw): """ result = 
[f"{request.script_root if request else ''}/"] for idx, arg in enumerate(args): - result.append(f"{'/' if idx else ''}{url_quote(arg)}") + result.append(f"{'/' if idx else ''}{quote(arg)}") if kw: - result.append(f"?{url_encode(kw)}") + result.append(f"?{urlencode(kw)}") return "".join(result) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..70721a9 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,104 @@ +[project] +name = "Werkzeug" +version = "3.0.1" +description = "The comprehensive WSGI web application library." +readme = "README.rst" +license = {file = "LICENSE.rst"} +maintainers = [{name = "Pallets", email = "contact@palletsprojects.com"}] +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content", + "Topic :: Internet :: WWW/HTTP :: WSGI", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Application", + "Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware", + "Topic :: Software Development :: Libraries :: Application Frameworks", +] +requires-python = ">=3.8" +dependencies = ["MarkupSafe>=2.1.1"] + +[project.urls] +Donate = "https://palletsprojects.com/donate" +Documentation = "https://werkzeug.palletsprojects.com/" +Changes = "https://werkzeug.palletsprojects.com/changes/" +"Source Code" = "https://github.com/pallets/werkzeug/" +"Issue Tracker" = "https://github.com/pallets/werkzeug/issues/" +Chat = "https://discord.gg/pallets" + +[project.optional-dependencies] +watchdog = ["watchdog>=2.3"] + +[build-system] +requires = ["flit_core<4"] +build-backend = "flit_core.buildapi" + +[tool.flit.module] +name = "werkzeug" + +[tool.flit.sdist] +include = [ + "docs/", + "examples/", + "requirements/", + "tests/", + "CHANGES.rst", + "tox.ini", +] +exclude = [ + "docs/_build/", +] + 
+[tool.pytest.ini_options] +testpaths = ["tests"] +filterwarnings = [ + "error", +] +markers = ["dev_server: tests that start the dev server"] + +[tool.coverage.run] +branch = true +source = ["werkzeug", "tests"] + +[tool.coverage.paths] +source = ["src", "*/site-packages"] + +[tool.mypy] +python_version = "3.8" +files = ["src/werkzeug"] +show_error_codes = true +pretty = true +#strict = true +allow_redefinition = true +disallow_subclassing_any = true +#disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +no_implicit_optional = true +local_partial_types = true +no_implicit_reexport = true +strict_equality = true +warn_redundant_casts = true +warn_unused_configs = true +warn_unused_ignores = true +warn_return_any = true +#warn_unreachable = True + +[[tool.mypy.overrides]] +module = ["werkzeug.wrappers"] +no_implicit_reexport = false + +[[tool.mypy.overrides]] +module = [ + "colorama.*", + "cryptography.*", + "eventlet.*", + "gevent.*", + "greenlet.*", + "watchdog.*", + "xprocess.*", +] +ignore_missing_imports = true diff --git a/requirements/build.in b/requirements/build.in new file mode 100644 index 0000000..378eac2 --- /dev/null +++ b/requirements/build.in @@ -0,0 +1 @@ +build diff --git a/requirements/build.txt b/requirements/build.txt new file mode 100644 index 0000000..196545d --- /dev/null +++ b/requirements/build.txt @@ -0,0 +1,13 @@ +# SHA1:80754af91bfb6d1073585b046fe0a474ce868509 +# +# This file is autogenerated by pip-compile-multi +# To update, run: +# +# pip-compile-multi +# +build==0.10.0 + # via -r requirements/build.in +packaging==23.1 + # via build +pyproject-hooks==1.0.0 + # via build diff --git a/requirements/dev.txt b/requirements/dev.txt index 50e233e..ed46208 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -8,55 +8,55 @@ -r docs.txt -r tests.txt -r typing.txt -build==0.8.0 +build==0.10.0 # via pip-tools +cachetools==5.3.1 + # via tox cfgv==3.3.1 # via pre-commit +chardet==5.1.0 + # via tox 
click==8.1.3 # via # pip-compile-multi # pip-tools -distlib==0.3.4 +colorama==0.4.6 + # via tox +distlib==0.3.6 # via virtualenv -filelock==3.7.1 +filelock==3.12.2 # via # tox # virtualenv -greenlet==1.1.2 ; python_version < "3.11" - # via -r requirements/tests.in -identify==2.5.1 +identify==2.5.24 # via pre-commit -nodeenv==1.7.0 +nodeenv==1.8.0 # via pre-commit -pep517==0.12.0 - # via build -pip-compile-multi==2.4.5 +pip-compile-multi==2.6.3 # via -r requirements/dev.in -pip-tools==6.8.0 +pip-tools==6.13.0 # via pip-compile-multi -platformdirs==2.5.2 - # via virtualenv -pre-commit==2.20.0 +platformdirs==3.8.0 + # via + # tox + # virtualenv +pre-commit==3.3.3 # via -r requirements/dev.in +pyproject-api==1.5.2 + # via tox +pyproject-hooks==1.0.0 + # via build pyyaml==6.0 # via pre-commit -six==1.16.0 - # via - # tox - # virtualenv -toml==0.10.2 - # via - # pre-commit - # tox -toposort==1.7 +toposort==1.10 # via pip-compile-multi -tox==3.25.1 +tox==4.6.3 # via -r requirements/dev.in -virtualenv==20.15.1 +virtualenv==20.23.1 # via # pre-commit # tox -wheel==0.37.1 +wheel==0.40.0 # via pip-tools # The following packages are considered to be unsafe in a requirements file: diff --git a/requirements/docs.txt b/requirements/docs.txt index 8238e78..e125c59 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -5,41 +5,37 @@ # # pip-compile-multi # -alabaster==0.7.12 +alabaster==0.7.13 # via sphinx -babel==2.10.3 +babel==2.12.1 # via sphinx -certifi==2022.6.15 +certifi==2023.5.7 # via requests -charset-normalizer==2.1.0 +charset-normalizer==3.1.0 # via requests -docutils==0.18.1 +docutils==0.20.1 # via sphinx -idna==3.3 +idna==3.4 # via requests imagesize==1.4.1 # via sphinx jinja2==3.1.2 # via sphinx -markupsafe==2.1.1 +markupsafe==2.1.3 # via jinja2 -packaging==21.3 +packaging==23.1 # via # pallets-sphinx-themes # sphinx -pallets-sphinx-themes==2.0.2 +pallets-sphinx-themes==2.1.1 # via -r requirements/docs.in -pygments==2.12.0 +pygments==2.15.1 # via sphinx 
-pyparsing==3.0.9 - # via packaging -pytz==2022.1 - # via babel -requests==2.28.1 +requests==2.31.0 # via sphinx snowballstemmer==2.2.0 # via sphinx -sphinx==5.0.2 +sphinx==7.0.1 # via # -r requirements/docs.in # pallets-sphinx-themes @@ -47,11 +43,11 @@ sphinx==5.0.2 # sphinxcontrib-log-cabinet sphinx-issues==3.0.1 # via -r requirements/docs.in -sphinxcontrib-applehelp==1.0.2 +sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx -sphinxcontrib-htmlhelp==2.0.0 +sphinxcontrib-htmlhelp==2.0.1 # via sphinx sphinxcontrib-jsmath==1.0.1 # via sphinx @@ -61,5 +57,5 @@ sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -urllib3==1.26.10 +urllib3==2.0.3 # via requests diff --git a/requirements/tests.txt b/requirements/tests.txt index 689d8ba..057d628 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -5,40 +5,32 @@ # # pip-compile-multi # -attrs==21.4.0 - # via pytest cffi==1.15.1 # via cryptography -cryptography==37.0.4 +cryptography==41.0.1 # via -r requirements/tests.in ephemeral-port-reserve==1.1.4 # via -r requirements/tests.in -greenlet==1.1.2 ; python_version < "3.11" - # via -r requirements/tests.in -iniconfig==1.1.1 +iniconfig==2.0.0 # via pytest -packaging==21.3 +packaging==23.1 # via pytest -pluggy==1.0.0 +pluggy==1.2.0 # via pytest -psutil==5.9.1 +psutil==5.9.5 # via pytest-xprocess py==1.11.0 - # via pytest + # via pytest-xprocess pycparser==2.21 # via cffi -pyparsing==3.0.9 - # via packaging -pytest==7.1.2 +pytest==7.4.0 # via # -r requirements/tests.in # pytest-timeout # pytest-xprocess pytest-timeout==2.1.0 # via -r requirements/tests.in -pytest-xprocess==0.19.0 +pytest-xprocess==0.22.2 # via -r requirements/tests.in -tomli==2.0.1 - # via pytest -watchdog==2.1.9 +watchdog==3.0.0 # via -r requirements/tests.in diff --git a/requirements/typing.in b/requirements/typing.in index e17c43d..23ab158 100644 --- a/requirements/typing.in +++ b/requirements/typing.in @@ -2,3 +2,4 
@@ mypy types-contextvars types-dataclasses types-setuptools +watchdog diff --git a/requirements/typing.txt b/requirements/typing.txt index 1f6de2c..99c46d2 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,21 +1,21 @@ -# SHA1:95499f7e92b572adde012b13e1ec99dbbb2f7089 +# SHA1:162796b1b3ac7a29da65fe0e32278f14b68ed8c8 # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # -mypy==0.961 +mypy==1.4.1 # via -r requirements/typing.in -mypy-extensions==0.4.3 +mypy-extensions==1.0.0 # via mypy -tomli==2.0.1 - # via mypy -types-contextvars==2.4.7 +types-contextvars==2.4.7.2 # via -r requirements/typing.in types-dataclasses==0.6.6 # via -r requirements/typing.in -types-setuptools==62.6.1 +types-setuptools==68.0.0.0 # via -r requirements/typing.in -typing-extensions==4.3.0 +typing-extensions==4.6.3 # via mypy +watchdog==3.0.0 + # via -r requirements/typing.in diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 2a1c2e4..0000000 --- a/setup.cfg +++ /dev/null @@ -1,130 +0,0 @@ -[metadata] -name = Werkzeug -version = attr: werkzeug.__version__ -url = https://palletsprojects.com/p/werkzeug/ -project_urls = - Donate = https://palletsprojects.com/donate - Documentation = https://werkzeug.palletsprojects.com/ - Changes = https://werkzeug.palletsprojects.com/changes/ - Source Code = https://github.com/pallets/werkzeug/ - Issue Tracker = https://github.com/pallets/werkzeug/issues/ - Twitter = https://twitter.com/PalletsTeam - Chat = https://discord.gg/pallets -license = BSD-3-Clause -author = Armin Ronacher -author_email = armin.ronacher@active-4.com -maintainer = Pallets -maintainer_email = contact@palletsprojects.com -description = The comprehensive WSGI web application library. 
-long_description = file: README.rst -long_description_content_type = text/x-rst -classifiers = - Development Status :: 5 - Production/Stable - Environment :: Web Environment - Intended Audience :: Developers - License :: OSI Approved :: BSD License - Operating System :: OS Independent - Programming Language :: Python - Topic :: Internet :: WWW/HTTP :: Dynamic Content - Topic :: Internet :: WWW/HTTP :: WSGI - Topic :: Internet :: WWW/HTTP :: WSGI :: Application - Topic :: Internet :: WWW/HTTP :: WSGI :: Middleware - Topic :: Software Development :: Libraries :: Application Frameworks - -[options] -packages = find: -package_dir = = src -include_package_data = True -python_requires = >= 3.7 -# Dependencies are in setup.py for GitHub's dependency graph. - -[options.packages.find] -where = src - -[tool:pytest] -testpaths = tests -filterwarnings = - error -markers = - dev_server: tests that start the dev server - -[coverage:run] -branch = True -source = - werkzeug - tests - -[coverage:paths] -source = - src - */site-packages - -[flake8] -# B = bugbear -# E = pycodestyle errors -# F = flake8 pyflakes -# W = pycodestyle warnings -# B9 = bugbear opinions -# ISC = implicit str concat -select = B, E, F, W, B9, ISC -ignore = - # slice notation whitespace, invalid - E203 - # import at top, too many circular import fixes - E402 - # line length, handled by bugbear B950 - E501 - # bare except, handled by bugbear B001 - E722 - # bin op line break, invalid - W503 -# up to 88 allowed by bugbear B950 -max-line-length = 80 -per-file-ignores = - # __init__ exports names - **/__init__.py: F401 - # LocalProxy assigns lambdas - src/werkzeug/local.py: E731 - -[mypy] -files = src/werkzeug -python_version = 3.7 -show_error_codes = True -allow_redefinition = True -disallow_subclassing_any = True -# disallow_untyped_calls = True -disallow_untyped_defs = True -disallow_incomplete_defs = True -no_implicit_optional = True -local_partial_types = True -no_implicit_reexport = True -strict_equality = 
True -warn_redundant_casts = True -warn_unused_configs = True -warn_unused_ignores = True -warn_return_any = True -# warn_unreachable = True - -[mypy-werkzeug.wrappers] -no_implicit_reexport = False - -[mypy-colorama.*] -ignore_missing_imports = True - -[mypy-cryptography.*] -ignore_missing_imports = True - -[mypy-eventlet.*] -ignore_missing_imports = True - -[mypy-gevent.*] -ignore_missing_imports = True - -[mypy-greenlet.*] -ignore_missing_imports = True - -[mypy-watchdog.*] -ignore_missing_imports = True - -[mypy-xprocess.*] -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 37d75a5..0000000 --- a/setup.py +++ /dev/null @@ -1,9 +0,0 @@ -#!/usr/bin/env python -from setuptools import setup - -# Metadata goes in setup.cfg. These are here for GitHub's dependency graph. -setup( - name="Werkzeug", - install_requires=["MarkupSafe>=2.1.1"], - extras_require={"watchdog": ["watchdog"]}, -) diff --git a/src/werkzeug/__init__.py b/src/werkzeug/__init__.py index fd7f8d2..57cb753 100644 --- a/src/werkzeug/__init__.py +++ b/src/werkzeug/__init__.py @@ -1,6 +1,25 @@ +from __future__ import annotations + +import typing as t + from .serving import run_simple as run_simple from .test import Client as Client from .wrappers import Request as Request from .wrappers import Response as Response -__version__ = "2.2.2" + +def __getattr__(name: str) -> t.Any: + if name == "__version__": + import importlib.metadata + import warnings + + warnings.warn( + "The '__version__' attribute is deprecated and will be removed in" + " Werkzeug 3.1. 
Use feature detection or" + " 'importlib.metadata.version(\"werkzeug\")' instead.", + DeprecationWarning, + stacklevel=2, + ) + return importlib.metadata.version("werkzeug") + + raise AttributeError(name) diff --git a/src/werkzeug/_internal.py b/src/werkzeug/_internal.py index 4636647..70ab687 100644 --- a/src/werkzeug/_internal.py +++ b/src/werkzeug/_internal.py @@ -1,50 +1,17 @@ +from __future__ import annotations + import logging -import operator import re -import string import sys -import typing import typing as t -from datetime import date from datetime import datetime from datetime import timezone -from itertools import chain -from weakref import WeakKeyDictionary if t.TYPE_CHECKING: - from _typeshed.wsgi import StartResponse - from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment - from .wrappers.request import Request # noqa: F401 + from .wrappers.request import Request -_logger: t.Optional[logging.Logger] = None -_signature_cache = WeakKeyDictionary() # type: ignore -_epoch_ord = date(1970, 1, 1).toordinal() -_legal_cookie_chars = frozenset( - c.encode("ascii") - for c in f"{string.ascii_letters}{string.digits}/=!#$%&'*+-.^_`|~:" -) - -_cookie_quoting_map = {b",": b"\\054", b";": b"\\073", b'"': b'\\"', b"\\": b"\\\\"} -for _i in chain(range(32), range(127, 256)): - _cookie_quoting_map[_i.to_bytes(1, sys.byteorder)] = f"\\{_i:03o}".encode("latin1") - -_octal_re = re.compile(rb"\\[0-3][0-7][0-7]") -_quote_re = re.compile(rb"[\\].") -_legal_cookie_chars_re = rb"[\w\d!#%&\'~_`><@,:/\$\*\+\-\.\^\|\)\(\?\}\{\=]" -_cookie_re = re.compile( - rb""" - (?P[^=;]+) - (?:\s*=\s* - (?P - "(?:[^\\"]|\\.)*" | - (?:.*?) - ) - )? - \s*; -""", - flags=re.VERBOSE, -) +_logger: logging.Logger | None = None class _Missing: @@ -58,110 +25,15 @@ class _Missing: _missing = _Missing() -@typing.overload -def _make_encode_wrapper(reference: str) -> t.Callable[[str], str]: - ... 
+def _wsgi_decoding_dance(s: str) -> str: + return s.encode("latin1").decode(errors="replace") -@typing.overload -def _make_encode_wrapper(reference: bytes) -> t.Callable[[str], bytes]: - ... +def _wsgi_encoding_dance(s: str) -> str: + return s.encode().decode("latin1") -def _make_encode_wrapper(reference: t.AnyStr) -> t.Callable[[str], t.AnyStr]: - """Create a function that will be called with a string argument. If - the reference is bytes, values will be encoded to bytes. - """ - if isinstance(reference, str): - return lambda x: x - - return operator.methodcaller("encode", "latin1") - - -def _check_str_tuple(value: t.Tuple[t.AnyStr, ...]) -> None: - """Ensure tuple items are all strings or all bytes.""" - if not value: - return - - item_type = str if isinstance(value[0], str) else bytes - - if any(not isinstance(item, item_type) for item in value): - raise TypeError(f"Cannot mix str and bytes arguments (got {value!r})") - - -_default_encoding = sys.getdefaultencoding() - - -def _to_bytes( - x: t.Union[str, bytes], charset: str = _default_encoding, errors: str = "strict" -) -> bytes: - if x is None or isinstance(x, bytes): - return x - - if isinstance(x, (bytearray, memoryview)): - return bytes(x) - - if isinstance(x, str): - return x.encode(charset, errors) - - raise TypeError("Expected bytes") - - -@typing.overload -def _to_str( # type: ignore - x: None, - charset: t.Optional[str] = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> None: - ... - - -@typing.overload -def _to_str( - x: t.Any, - charset: t.Optional[str] = ..., - errors: str = ..., - allow_none_charset: bool = ..., -) -> str: - ... 
- - -def _to_str( - x: t.Optional[t.Any], - charset: t.Optional[str] = _default_encoding, - errors: str = "strict", - allow_none_charset: bool = False, -) -> t.Optional[t.Union[str, bytes]]: - if x is None or isinstance(x, str): - return x - - if not isinstance(x, (bytes, bytearray)): - return str(x) - - if charset is None: - if allow_none_charset: - return x - - return x.decode(charset, errors) # type: ignore - - -def _wsgi_decoding_dance( - s: str, charset: str = "utf-8", errors: str = "replace" -) -> str: - return s.encode("latin1").decode(charset, errors) - - -def _wsgi_encoding_dance( - s: str, charset: str = "utf-8", errors: str = "replace" -) -> str: - if isinstance(s, bytes): - return s.decode("latin1", errors) - - return s.encode(charset).decode("latin1", errors) - - -def _get_environ(obj: t.Union["WSGIEnvironment", "Request"]) -> "WSGIEnvironment": +def _get_environ(obj: WSGIEnvironment | Request) -> WSGIEnvironment: env = getattr(obj, "environ", obj) assert isinstance( env, dict @@ -224,17 +96,17 @@ def _log(type: str, message: str, *args: t.Any, **kwargs: t.Any) -> None: getattr(_logger, type)(message.rstrip(), *args, **kwargs) -@typing.overload +@t.overload def _dt_as_utc(dt: None) -> None: ... -@typing.overload +@t.overload def _dt_as_utc(dt: datetime) -> datetime: ... 
-def _dt_as_utc(dt: t.Optional[datetime]) -> t.Optional[datetime]: +def _dt_as_utc(dt: datetime | None) -> datetime | None: if dt is None: return dt @@ -257,11 +129,11 @@ class _DictAccessorProperty(t.Generic[_TAccessorValue]): def __init__( self, name: str, - default: t.Optional[_TAccessorValue] = None, - load_func: t.Optional[t.Callable[[str], _TAccessorValue]] = None, - dump_func: t.Optional[t.Callable[[_TAccessorValue], str]] = None, - read_only: t.Optional[bool] = None, - doc: t.Optional[str] = None, + default: _TAccessorValue | None = None, + load_func: t.Callable[[str], _TAccessorValue] | None = None, + dump_func: t.Callable[[_TAccessorValue], str] | None = None, + read_only: bool | None = None, + doc: str | None = None, ) -> None: self.name = name self.default = default @@ -274,19 +146,19 @@ class _DictAccessorProperty(t.Generic[_TAccessorValue]): def lookup(self, instance: t.Any) -> t.MutableMapping[str, t.Any]: raise NotImplementedError - @typing.overload + @t.overload def __get__( self, instance: None, owner: type - ) -> "_DictAccessorProperty[_TAccessorValue]": + ) -> _DictAccessorProperty[_TAccessorValue]: ... - @typing.overload + @t.overload def __get__(self, instance: t.Any, owner: type) -> _TAccessorValue: ... 
def __get__( - self, instance: t.Optional[t.Any], owner: type - ) -> t.Union[_TAccessorValue, "_DictAccessorProperty[_TAccessorValue]"]: + self, instance: t.Any | None, owner: type + ) -> _TAccessorValue | _DictAccessorProperty[_TAccessorValue]: if instance is None: return self @@ -324,225 +196,19 @@ class _DictAccessorProperty(t.Generic[_TAccessorValue]): return f"<{type(self).__name__} {self.name}>" -def _cookie_quote(b: bytes) -> bytes: - buf = bytearray() - all_legal = True - _lookup = _cookie_quoting_map.get - _push = buf.extend - - for char_int in b: - char = char_int.to_bytes(1, sys.byteorder) - if char not in _legal_cookie_chars: - all_legal = False - char = _lookup(char, char) - _push(char) - - if all_legal: - return bytes(buf) - return bytes(b'"' + buf + b'"') +_plain_int_re = re.compile(r"-?\d+", re.ASCII) -def _cookie_unquote(b: bytes) -> bytes: - if len(b) < 2: - return b - if b[:1] != b'"' or b[-1:] != b'"': - return b +def _plain_int(value: str) -> int: + """Parse an int only if it is only ASCII digits and ``-``. - b = b[1:-1] + This disallows ``+``, ``_``, and non-ASCII digits, which are accepted by ``int`` but + are not allowed in HTTP header values. 
- i = 0 - n = len(b) - rv = bytearray() - _push = rv.extend + Any leading or trailing whitespace is stripped + """ + value = value.strip() + if _plain_int_re.fullmatch(value) is None: + raise ValueError - while 0 <= i < n: - o_match = _octal_re.search(b, i) - q_match = _quote_re.search(b, i) - if not o_match and not q_match: - rv.extend(b[i:]) - break - j = k = -1 - if o_match: - j = o_match.start(0) - if q_match: - k = q_match.start(0) - if q_match and (not o_match or k < j): - _push(b[i:k]) - _push(b[k + 1 : k + 2]) - i = k + 2 - else: - _push(b[i:j]) - rv.append(int(b[j + 1 : j + 4], 8)) - i = j + 4 - - return bytes(rv) - - -def _cookie_parse_impl(b: bytes) -> t.Iterator[t.Tuple[bytes, bytes]]: - """Lowlevel cookie parsing facility that operates on bytes.""" - i = 0 - n = len(b) - - while i < n: - match = _cookie_re.search(b + b";", i) - if not match: - break - - key = match.group("key").strip() - value = match.group("val") or b"" - i = match.end(0) - - yield key, _cookie_unquote(value) - - -def _encode_idna(domain: str) -> bytes: - # If we're given bytes, make sure they fit into ASCII - if isinstance(domain, bytes): - domain.decode("ascii") - return domain - - # Otherwise check if it's already ascii, then return - try: - return domain.encode("ascii") - except UnicodeError: - pass - - # Otherwise encode each part separately - return b".".join(p.encode("idna") for p in domain.split(".")) - - -def _decode_idna(domain: t.Union[str, bytes]) -> str: - # If the input is a string try to encode it to ascii to do the idna - # decoding. If that fails because of a unicode error, then we - # already have a decoded idna domain. - if isinstance(domain, str): - try: - domain = domain.encode("ascii") - except UnicodeError: - return domain # type: ignore - - # Decode each part separately. If a part fails, try to decode it - # with ascii and silently ignore errors. This makes sense because - # the idna codec does not have error handling. 
- def decode_part(part: bytes) -> str: - try: - return part.decode("idna") - except UnicodeError: - return part.decode("ascii", "ignore") - - return ".".join(decode_part(p) for p in domain.split(b".")) - - -@typing.overload -def _make_cookie_domain(domain: None) -> None: - ... - - -@typing.overload -def _make_cookie_domain(domain: str) -> bytes: - ... - - -def _make_cookie_domain(domain: t.Optional[str]) -> t.Optional[bytes]: - if domain is None: - return None - domain = _encode_idna(domain) - if b":" in domain: - domain = domain.split(b":", 1)[0] - if b"." in domain: - return domain - raise ValueError( - "Setting 'domain' for a cookie on a server running locally (ex: " - "localhost) is not supported by complying browsers. You should " - "have something like: '127.0.0.1 localhost dev.localhost' on " - "your hosts file and then point your server to run on " - "'dev.localhost' and also set 'domain' for 'dev.localhost'" - ) - - -def _easteregg(app: t.Optional["WSGIApplication"] = None) -> "WSGIApplication": - """Like the name says. 
But who knows how it works?""" - - def bzzzzzzz(gyver: bytes) -> str: - import base64 - import zlib - - return zlib.decompress(base64.b64decode(gyver)).decode("ascii") - - gyver = "\n".join( - [ - x + (77 - len(x)) * " " - for x in bzzzzzzz( - b""" -eJyFlzuOJDkMRP06xRjymKgDJCDQStBYT8BCgK4gTwfQ2fcFs2a2FzvZk+hvlcRvRJD148efHt9m -9Xz94dRY5hGt1nrYcXx7us9qlcP9HHNh28rz8dZj+q4rynVFFPdlY4zH873NKCexrDM6zxxRymzz -4QIxzK4bth1PV7+uHn6WXZ5C4ka/+prFzx3zWLMHAVZb8RRUxtFXI5DTQ2n3Hi2sNI+HK43AOWSY -jmEzE4naFp58PdzhPMdslLVWHTGUVpSxImw+pS/D+JhzLfdS1j7PzUMxij+mc2U0I9zcbZ/HcZxc -q1QjvvcThMYFnp93agEx392ZdLJWXbi/Ca4Oivl4h/Y1ErEqP+lrg7Xa4qnUKu5UE9UUA4xeqLJ5 -jWlPKJvR2yhRI7xFPdzPuc6adXu6ovwXwRPXXnZHxlPtkSkqWHilsOrGrvcVWXgGP3daXomCj317 -8P2UOw/NnA0OOikZyFf3zZ76eN9QXNwYdD8f8/LdBRFg0BO3bB+Pe/+G8er8tDJv83XTkj7WeMBJ -v/rnAfdO51d6sFglfi8U7zbnr0u9tyJHhFZNXYfH8Iafv2Oa+DT6l8u9UYlajV/hcEgk1x8E8L/r -XJXl2SK+GJCxtnyhVKv6GFCEB1OO3f9YWAIEbwcRWv/6RPpsEzOkXURMN37J0PoCSYeBnJQd9Giu -LxYQJNlYPSo/iTQwgaihbART7Fcyem2tTSCcwNCs85MOOpJtXhXDe0E7zgZJkcxWTar/zEjdIVCk -iXy87FW6j5aGZhttDBoAZ3vnmlkx4q4mMmCdLtnHkBXFMCReqthSGkQ+MDXLLCpXwBs0t+sIhsDI -tjBB8MwqYQpLygZ56rRHHpw+OAVyGgaGRHWy2QfXez+ZQQTTBkmRXdV/A9LwH6XGZpEAZU8rs4pE -1R4FQ3Uwt8RKEtRc0/CrANUoes3EzM6WYcFyskGZ6UTHJWenBDS7h163Eo2bpzqxNE9aVgEM2CqI -GAJe9Yra4P5qKmta27VjzYdR04Vc7KHeY4vs61C0nbywFmcSXYjzBHdiEjraS7PGG2jHHTpJUMxN -Jlxr3pUuFvlBWLJGE3GcA1/1xxLcHmlO+LAXbhrXah1tD6Ze+uqFGdZa5FM+3eHcKNaEarutAQ0A -QMAZHV+ve6LxAwWnXbbSXEG2DmCX5ijeLCKj5lhVFBrMm+ryOttCAeFpUdZyQLAQkA06RLs56rzG -8MID55vqr/g64Qr/wqwlE0TVxgoiZhHrbY2h1iuuyUVg1nlkpDrQ7Vm1xIkI5XRKLedN9EjzVchu -jQhXcVkjVdgP2O99QShpdvXWoSwkp5uMwyjt3jiWCqWGSiaaPAzohjPanXVLbM3x0dNskJsaCEyz -DTKIs+7WKJD4ZcJGfMhLFBf6hlbnNkLEePF8Cx2o2kwmYF4+MzAxa6i+6xIQkswOqGO+3x9NaZX8 -MrZRaFZpLeVTYI9F/djY6DDVVs340nZGmwrDqTCiiqD5luj3OzwpmQCiQhdRYowUYEA3i1WWGwL4 -GCtSoO4XbIPFeKGU13XPkDf5IdimLpAvi2kVDVQbzOOa4KAXMFlpi/hV8F6IDe0Y2reg3PuNKT3i -RYhZqtkQZqSB2Qm0SGtjAw7RDwaM1roESC8HWiPxkoOy0lLTRFG39kvbLZbU9gFKFRvixDZBJmpi 
-Xyq3RE5lW00EJjaqwp/v3EByMSpVZYsEIJ4APaHmVtpGSieV5CALOtNUAzTBiw81GLgC0quyzf6c -NlWknzJeCsJ5fup2R4d8CYGN77mu5vnO1UqbfElZ9E6cR6zbHjgsr9ly18fXjZoPeDjPuzlWbFwS -pdvPkhntFvkc13qb9094LL5NrA3NIq3r9eNnop9DizWOqCEbyRBFJTHn6Tt3CG1o8a4HevYh0XiJ -sR0AVVHuGuMOIfbuQ/OKBkGRC6NJ4u7sbPX8bG/n5sNIOQ6/Y/BX3IwRlTSabtZpYLB85lYtkkgm -p1qXK3Du2mnr5INXmT/78KI12n11EFBkJHHp0wJyLe9MvPNUGYsf+170maayRoy2lURGHAIapSpQ -krEDuNoJCHNlZYhKpvw4mspVWxqo415n8cD62N9+EfHrAvqQnINStetek7RY2Urv8nxsnGaZfRr/ -nhXbJ6m/yl1LzYqscDZA9QHLNbdaSTTr+kFg3bC0iYbX/eQy0Bv3h4B50/SGYzKAXkCeOLI3bcAt -mj2Z/FM1vQWgDynsRwNvrWnJHlespkrp8+vO1jNaibm+PhqXPPv30YwDZ6jApe3wUjFQobghvW9p -7f2zLkGNv8b191cD/3vs9Q833z8t""" - ).splitlines() - ] - ) - - def easteregged( - environ: "WSGIEnvironment", start_response: "StartResponse" - ) -> t.Iterable[bytes]: - def injecting_start_response( - status: str, headers: t.List[t.Tuple[str, str]], exc_info: t.Any = None - ) -> t.Callable[[bytes], t.Any]: - headers.append(("X-Powered-By", "Werkzeug")) - return start_response(status, headers, exc_info) - - if app is not None and environ.get("QUERY_STRING") != "macgybarchakku": - return app(environ, injecting_start_response) - injecting_start_response("200 OK", [("Content-Type", "text/html")]) - return [ - f"""\ - - - -About Werkzeug - - - -

Werkzeug

-

the Swiss Army knife of Python web development.

-
{gyver}\n\n\n
- -""".encode( - "latin1" - ) - ] - - return easteregged + return int(value) diff --git a/src/werkzeug/_reloader.py b/src/werkzeug/_reloader.py index 57f3117..c868359 100644 --- a/src/werkzeug/_reloader.py +++ b/src/werkzeug/_reloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import fnmatch import os import subprocess @@ -20,7 +22,7 @@ prefix = {*_ignore_always, sys.prefix, sys.exec_prefix} if hasattr(sys, "real_prefix"): # virtualenv < 20 - prefix.add(sys.real_prefix) # type: ignore[attr-defined] + prefix.add(sys.real_prefix) _stat_ignore_scan = tuple(prefix) del prefix @@ -55,13 +57,13 @@ def _iter_module_paths() -> t.Iterator[str]: yield name -def _remove_by_pattern(paths: t.Set[str], exclude_patterns: t.Set[str]) -> None: +def _remove_by_pattern(paths: set[str], exclude_patterns: set[str]) -> None: for pattern in exclude_patterns: paths.difference_update(fnmatch.filter(paths, pattern)) def _find_stat_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. Returns imported module files, Python files under non-system paths. Extra files and @@ -115,7 +117,7 @@ def _find_stat_paths( def _find_watchdog_paths( - extra_files: t.Set[str], exclude_patterns: t.Set[str] + extra_files: set[str], exclude_patterns: set[str] ) -> t.Iterable[str]: """Find paths for the stat reloader to watch. 
Looks at the same sources as the stat reloader, but watches everything under @@ -139,7 +141,7 @@ def _find_watchdog_paths( def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: - root: t.Dict[str, dict] = {} + root: dict[str, dict] = {} for chunks in sorted((PurePath(x).parts for x in paths), key=len, reverse=True): node = root @@ -151,7 +153,7 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: rv = set() - def _walk(node: t.Mapping[str, dict], path: t.Tuple[str, ...]) -> None: + def _walk(node: t.Mapping[str, dict], path: tuple[str, ...]) -> None: for prefix, child in node.items(): _walk(child, path + (prefix,)) @@ -162,10 +164,15 @@ def _find_common_roots(paths: t.Iterable[str]) -> t.Iterable[str]: return rv -def _get_args_for_reloading() -> t.List[str]: +def _get_args_for_reloading() -> list[str]: """Determine how the script was executed, and return the args needed to execute it again in a new process. """ + if sys.version_info >= (3, 10): + # sys.orig_argv, added in Python 3.10, contains the exact args used to invoke + # Python. Still replace argv[0] with sys.executable for accuracy. 
+ return [sys.executable, *sys.orig_argv[1:]] + rv = [sys.executable] py_script = sys.argv[0] args = sys.argv[1:] @@ -221,15 +228,15 @@ class ReloaderLoop: def __init__( self, - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, ) -> None: - self.extra_files: t.Set[str] = {os.path.abspath(x) for x in extra_files or ()} - self.exclude_patterns: t.Set[str] = set(exclude_patterns or ()) + self.extra_files: set[str] = {os.path.abspath(x) for x in extra_files or ()} + self.exclude_patterns: set[str] = set(exclude_patterns or ()) self.interval = interval - def __enter__(self) -> "ReloaderLoop": + def __enter__(self) -> ReloaderLoop: """Do any setup, then run one step of the watch to populate the initial filesystem state. """ @@ -281,7 +288,7 @@ class StatReloaderLoop(ReloaderLoop): name = "stat" def __enter__(self) -> ReloaderLoop: - self.mtimes: t.Dict[str, float] = {} + self.mtimes: dict[str, float] = {} return super().__enter__() def run_step(self) -> None: @@ -305,15 +312,20 @@ class WatchdogReloaderLoop(ReloaderLoop): def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: from watchdog.observers import Observer from watchdog.events import PatternMatchingEventHandler + from watchdog.events import EVENT_TYPE_OPENED + from watchdog.events import FileModifiedEvent super().__init__(*args, **kwargs) trigger_reload = self.trigger_reload - class EventHandler(PatternMatchingEventHandler): # type: ignore - def on_any_event(self, event): # type: ignore + class EventHandler(PatternMatchingEventHandler): + def on_any_event(self, event: FileModifiedEvent): # type: ignore + if event.event_type == EVENT_TYPE_OPENED: + return + trigger_reload(event.src_path) - reloader_name = Observer.__name__.lower() + reloader_name = Observer.__name__.lower() # type: 
ignore[attr-defined] if reloader_name.endswith("observer"): reloader_name = reloader_name[:-8] @@ -343,7 +355,7 @@ class WatchdogReloaderLoop(ReloaderLoop): self.log_reload(filename) def __enter__(self) -> ReloaderLoop: - self.watches: t.Dict[str, t.Any] = {} + self.watches: dict[str, t.Any] = {} self.observer.start() return super().__enter__() @@ -382,7 +394,7 @@ class WatchdogReloaderLoop(ReloaderLoop): self.observer.unschedule(watch) -reloader_loops: t.Dict[str, t.Type[ReloaderLoop]] = { +reloader_loops: dict[str, type[ReloaderLoop]] = { "stat": StatReloaderLoop, "watchdog": WatchdogReloaderLoop, } @@ -416,9 +428,9 @@ def ensure_echo_on() -> None: def run_with_reloader( main_func: t.Callable[[], None], - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, - interval: t.Union[int, float] = 1, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, + interval: int | float = 1, reloader_type: str = "auto", ) -> None: """Run the given function in an independent Python interpreter.""" diff --git a/src/werkzeug/datastructures.py b/src/werkzeug/datastructures.py deleted file mode 100644 index 43ee8c7..0000000 --- a/src/werkzeug/datastructures.py +++ /dev/null @@ -1,3040 +0,0 @@ -import base64 -import codecs -import mimetypes -import os -import re -from collections.abc import Collection -from collections.abc import MutableSet -from copy import deepcopy -from io import BytesIO -from itertools import repeat -from os import fspath - -from . import exceptions -from ._internal import _missing - - -def is_immutable(self): - raise TypeError(f"{type(self).__name__!r} objects are immutable") - - -def iter_multi_items(mapping): - """Iterates over the items of a mapping yielding keys and values - without dropping any from more complex structures. 
- """ - if isinstance(mapping, MultiDict): - yield from mapping.items(multi=True) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - for v in value: - yield key, v - else: - yield key, value - else: - yield from mapping - - -class ImmutableListMixin: - """Makes a :class:`list` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) - return rv - - def __reduce_ex__(self, protocol): - return type(self), (list(self),) - - def __delitem__(self, key): - is_immutable(self) - - def __iadd__(self, other): - is_immutable(self) - - def __imul__(self, other): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def append(self, item): - is_immutable(self) - - def remove(self, item): - is_immutable(self) - - def extend(self, iterable): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, index=-1): - is_immutable(self) - - def reverse(self): - is_immutable(self) - - def sort(self, key=None, reverse=False): - is_immutable(self) - - -class ImmutableList(ImmutableListMixin, list): - """An immutable :class:`list`. - - .. versionadded:: 0.5 - - :private: - """ - - def __repr__(self): - return f"{type(self).__name__}({list.__repr__(self)})" - - -class ImmutableDictMixin: - """Makes a :class:`dict` immutable. - - .. 
versionadded:: 0.5 - - :private: - """ - - _hash_cache = None - - @classmethod - def fromkeys(cls, keys, value=None): - instance = super().__new__(cls) - instance.__init__(zip(keys, repeat(value))) - return instance - - def __reduce_ex__(self, protocol): - return type(self), (dict(self),) - - def _iter_hashitems(self): - return self.items() - - def __hash__(self): - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) - return rv - - def setdefault(self, key, default=None): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def pop(self, key, default=None): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def __delitem__(self, key): - is_immutable(self) - - def clear(self): - is_immutable(self) - - -class ImmutableMultiDictMixin(ImmutableDictMixin): - """Makes a :class:`MultiDict` immutable. - - .. versionadded:: 0.5 - - :private: - """ - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def _iter_hashitems(self): - return self.items(multi=True) - - def add(self, key, value): - is_immutable(self) - - def popitemlist(self): - is_immutable(self) - - def poplist(self, key): - is_immutable(self) - - def setlist(self, key, new_list): - is_immutable(self) - - def setlistdefault(self, key, default_list=None): - is_immutable(self) - - -def _calls_update(name): - def oncall(self, *args, **kw): - rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) - - if self.on_update is not None: - self.on_update(self) - - return rv - - oncall.__name__ = name - return oncall - - -class UpdateDictMixin(dict): - """Makes dicts call `self.on_update` on modifications. - - .. 
versionadded:: 0.5 - - :private: - """ - - on_update = None - - def setdefault(self, key, default=None): - modified = key not in self - rv = super().setdefault(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - def pop(self, key, default=_missing): - modified = key in self - if default is _missing: - rv = super().pop(key) - else: - rv = super().pop(key, default) - if modified and self.on_update is not None: - self.on_update(self) - return rv - - __setitem__ = _calls_update("__setitem__") - __delitem__ = _calls_update("__delitem__") - clear = _calls_update("clear") - popitem = _calls_update("popitem") - update = _calls_update("update") - - -class TypeConversionDict(dict): - """Works like a regular dict but the :meth:`get` method can perform - type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` - are subclasses of this class and provide the same feature. - - .. versionadded:: 0.5 - """ - - def get(self, key, default=None, type=None): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. In - this case the function will return the default as if the value was not - found: - - >>> d = TypeConversionDict(foo='42', bar='blub') - >>> d.get('foo', type=int) - 42 - >>> d.get('bar', -1, type=int) - -1 - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the default value is returned. 
- """ - try: - rv = self[key] - except KeyError: - return default - if type is not None: - try: - rv = type(rv) - except ValueError: - rv = default - return rv - - -class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): - """Works like a :class:`TypeConversionDict` but does not support - modifications. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return TypeConversionDict(self) - - def __copy__(self): - return self - - -class MultiDict(TypeConversionDict): - """A :class:`MultiDict` is a dictionary subclass customized to deal with - multiple values for the same key which is for example used by the parsing - functions in the wrappers. This is necessary because some HTML form - elements pass multiple values for the same key. - - :class:`MultiDict` implements all standard dictionary methods. - Internally, it saves all values for a key as a list, but the standard dict - access methods will only return the first value for a key. If you want to - gain access to the other values, too, you have to use the `list` methods as - explained below. - - Basic Usage: - - >>> d = MultiDict([('a', 'b'), ('a', 'c')]) - >>> d - MultiDict([('a', 'b'), ('a', 'c')]) - >>> d['a'] - 'b' - >>> d.getlist('a') - ['b', 'c'] - >>> 'a' in d - True - - It behaves like a normal dict thus all dict functions will only return the - first value when multiple values for one key are found. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. 
- - A :class:`MultiDict` can be constructed from an iterable of - ``(key, value)`` tuples, a dict, a :class:`MultiDict` or from Werkzeug 0.2 - onwards some keyword parameters. - - :param mapping: the initial value for the :class:`MultiDict`. Either a - regular dict, an iterable of ``(key, value)`` tuples - or `None`. - """ - - def __init__(self, mapping=None): - if isinstance(mapping, MultiDict): - dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) - elif isinstance(mapping, dict): - tmp = {} - for key, value in mapping.items(): - if isinstance(value, (tuple, list)): - if len(value) == 0: - continue - value = list(value) - else: - value = [value] - tmp[key] = value - dict.__init__(self, tmp) - else: - tmp = {} - for key, value in mapping or (): - tmp.setdefault(key, []).append(value) - dict.__init__(self, tmp) - - def __getstate__(self): - return dict(self.lists()) - - def __setstate__(self, value): - dict.clear(self) - dict.update(self, value) - - def __iter__(self): - # Work around https://bugs.python.org/issue43246. - # (`return super().__iter__()` also works here, which makes this look - # even more like it should be a no-op, yet it isn't.) - return dict.__iter__(self) - - def __getitem__(self, key): - """Return the first data value for this key; - raises KeyError if not found. - - :param key: The key to be looked up. - :raise KeyError: if the key does not exist. - """ - - if key in self: - lst = dict.__getitem__(self, key) - if len(lst) > 0: - return lst[0] - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - """Like :meth:`add` but removes an existing key first. - - :param key: the key for the value. - :param value: the value to set. - """ - dict.__setitem__(self, key, [value]) - - def add(self, key, value): - """Adds a new value for the key. - - .. versionadded:: 0.6 - - :param key: the key for the value. - :param value: the value to add. 
- """ - dict.setdefault(self, key, []).append(value) - - def getlist(self, key, type=None): - """Return the list of items for a given key. If that key is not in the - `MultiDict`, the return value will be an empty list. Just like `get`, - `getlist` accepts a `type` parameter. All items will be converted - with the callable defined there. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`MultiDict`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - """ - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return list(rv) - result = [] - for item in rv: - try: - result.append(type(item)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - """Remove the old values for a key and add new ones. Note that the list - you pass the values in will be shallow-copied before it is inserted in - the dictionary. - - >>> d = MultiDict() - >>> d.setlist('foo', ['1', '2']) - >>> d['foo'] - '1' - >>> d.getlist('foo') - ['1', '2'] - - :param key: The key for which the values are set. - :param new_list: An iterable with the new values for the key. Old values - are removed first. - """ - dict.__setitem__(self, key, list(new_list)) - - def setdefault(self, key, default=None): - """Returns the value for the key if it is in the dict, otherwise it - returns `default` and sets that value for `key`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key is not - in the dict. If not further specified it's `None`. - """ - if key not in self: - self[key] = default - else: - default = self[key] - return default - - def setlistdefault(self, key, default_list=None): - """Like `setdefault` but sets multiple values. The list returned - is not a copy, but the list that is actually used internally. 
This - means that you can put new values into the dict by appending items - to the list: - - >>> d = MultiDict({"foo": 1}) - >>> d.setlistdefault("foo").extend([2, 3]) - >>> d.getlist("foo") - [1, 2, 3] - - :param key: The key to be looked up. - :param default_list: An iterable of default values. It is either copied - (in case it was a list) or converted into a list - before returned. - :return: a :class:`list` - """ - if key not in self: - default_list = list(default_list or ()) - dict.__setitem__(self, key, default_list) - else: - default_list = dict.__getitem__(self, key) - return default_list - - def items(self, multi=False): - """Return an iterator of ``(key, value)`` pairs. - - :param multi: If set to `True` the iterator returned will have a pair - for each value of each key. Otherwise it will only - contain pairs for the first value of each key. - """ - for key, values in dict.items(self): - if multi: - for value in values: - yield key, value - else: - yield key, values[0] - - def lists(self): - """Return a iterator of ``(key, values)`` pairs, where values is the list - of all values associated with the key.""" - for key, values in dict.items(self): - yield key, list(values) - - def values(self): - """Returns an iterator of the first value on every key's value list.""" - for values in dict.values(self): - yield values[0] - - def listvalues(self): - """Return an iterator of all values associated with a key. Zipping - :meth:`keys` and this is the same as calling :meth:`lists`: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> zip(d.keys(), d.listvalues()) == d.lists() - True - """ - return dict.values(self) - - def copy(self): - """Return a shallow copy of this object.""" - return self.__class__(self) - - def deepcopy(self, memo=None): - """Return a deep copy of this object.""" - return self.__class__(deepcopy(self.to_dict(flat=False), memo)) - - def to_dict(self, flat=True): - """Return the contents as regular dict. 
If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first value for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - return dict(self.lists()) - - def update(self, mapping): - """update() extends rather than replaces existing key lists: - - >>> a = MultiDict({'x': 1}) - >>> b = MultiDict({'x': 2, 'y': 3}) - >>> a.update(b) - >>> a - MultiDict([('y', 3), ('x', 1), ('x', 2)]) - - If the value list for a key in ``other_dict`` is empty, no new values - will be added to the dict and the key will not be created: - - >>> x = {'empty_list': []} - >>> y = MultiDict() - >>> y.update(x) - >>> y - MultiDict([]) - """ - for key, value in iter_multi_items(mapping): - MultiDict.add(self, key, value) - - def pop(self, key, default=_missing): - """Pop the first item for a list on the dict. Afterwards the - key is removed from the dict, so additional values are discarded: - - >>> d = MultiDict({"foo": [1, 2, 3]}) - >>> d.pop("foo") - 1 - >>> "foo" in d - False - - :param key: the key to pop. - :param default: if provided the value to return if the key was - not in the dictionary. - """ - try: - lst = dict.pop(self, key) - - if len(lst) == 0: - raise exceptions.BadRequestKeyError(key) - - return lst[0] - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - def popitem(self): - """Pop an item from the dict.""" - try: - item = dict.popitem(self) - - if len(item[1]) == 0: - raise exceptions.BadRequestKeyError(item[0]) - - return (item[0], item[1][0]) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def poplist(self, key): - """Pop the list for a key from the dict. If the key is not in the dict - an empty list is returned. 
- - .. versionchanged:: 0.5 - If the key does no longer exist a list is returned instead of - raising an error. - """ - return dict.pop(self, key, []) - - def popitemlist(self): - """Pop a ``(key, list)`` tuple from the dict.""" - try: - return dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - def __copy__(self): - return self.copy() - - def __deepcopy__(self, memo): - return self.deepcopy(memo=memo) - - def __repr__(self): - return f"{type(self).__name__}({list(self.items(multi=True))!r})" - - -class _omd_bucket: - """Wraps values in the :class:`OrderedMultiDict`. This makes it - possible to keep an order over multiple different keys. It requires - a lot of extra memory and slows down access a lot, but makes it - possible to access elements in O(1) and iterate in O(n). - """ - - __slots__ = ("prev", "key", "value", "next") - - def __init__(self, omd, key, value): - self.prev = omd._last_bucket - self.key = key - self.value = value - self.next = None - - if omd._first_bucket is None: - omd._first_bucket = self - if omd._last_bucket is not None: - omd._last_bucket.next = self - omd._last_bucket = self - - def unlink(self, omd): - if self.prev: - self.prev.next = self.next - if self.next: - self.next.prev = self.prev - if omd._first_bucket is self: - omd._first_bucket = self.next - if omd._last_bucket is self: - omd._last_bucket = self.prev - - -class OrderedMultiDict(MultiDict): - """Works like a regular :class:`MultiDict` but preserves the - order of the fields. To convert the ordered multi dict into a - list you can use the :meth:`items` method and pass it ``multi=True``. - - In general an :class:`OrderedMultiDict` is an order of magnitude - slower than a :class:`MultiDict`. - - .. admonition:: note - - Due to a limitation in Python you cannot convert an ordered - multi dict into a regular dict by using ``dict(multidict)``. 
- Instead you have to use the :meth:`to_dict` method, otherwise - the internal bucket objects are exposed. - """ - - def __init__(self, mapping=None): - dict.__init__(self) - self._first_bucket = self._last_bucket = None - if mapping is not None: - OrderedMultiDict.update(self, mapping) - - def __eq__(self, other): - if not isinstance(other, MultiDict): - return NotImplemented - if isinstance(other, OrderedMultiDict): - iter1 = iter(self.items(multi=True)) - iter2 = iter(other.items(multi=True)) - try: - for k1, v1 in iter1: - k2, v2 = next(iter2) - if k1 != k2 or v1 != v2: - return False - except StopIteration: - return False - try: - next(iter2) - except StopIteration: - return True - return False - if len(self) != len(other): - return False - for key, values in self.lists(): - if other.getlist(key) != values: - return False - return True - - __hash__ = None - - def __reduce_ex__(self, protocol): - return type(self), (list(self.items(multi=True)),) - - def __getstate__(self): - return list(self.items(multi=True)) - - def __setstate__(self, values): - dict.clear(self) - for key, value in values: - self.add(key, value) - - def __getitem__(self, key): - if key in self: - return dict.__getitem__(self, key)[0].value - raise exceptions.BadRequestKeyError(key) - - def __setitem__(self, key, value): - self.poplist(key) - self.add(key, value) - - def __delitem__(self, key): - self.pop(key) - - def keys(self): - return (key for key, value in self.items()) - - def __iter__(self): - return iter(self.keys()) - - def values(self): - return (value for key, value in self.items()) - - def items(self, multi=False): - ptr = self._first_bucket - if multi: - while ptr is not None: - yield ptr.key, ptr.value - ptr = ptr.next - else: - returned_keys = set() - while ptr is not None: - if ptr.key not in returned_keys: - returned_keys.add(ptr.key) - yield ptr.key, ptr.value - ptr = ptr.next - - def lists(self): - returned_keys = set() - ptr = self._first_bucket - while ptr is not None: - 
if ptr.key not in returned_keys: - yield ptr.key, self.getlist(ptr.key) - returned_keys.add(ptr.key) - ptr = ptr.next - - def listvalues(self): - for _key, values in self.lists(): - yield values - - def add(self, key, value): - dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) - - def getlist(self, key, type=None): - try: - rv = dict.__getitem__(self, key) - except KeyError: - return [] - if type is None: - return [x.value for x in rv] - result = [] - for item in rv: - try: - result.append(type(item.value)) - except ValueError: - pass - return result - - def setlist(self, key, new_list): - self.poplist(key) - for value in new_list: - self.add(key, value) - - def setlistdefault(self, key, default_list=None): - raise TypeError("setlistdefault is unsupported for ordered multi dicts") - - def update(self, mapping): - for key, value in iter_multi_items(mapping): - OrderedMultiDict.add(self, key, value) - - def poplist(self, key): - buckets = dict.pop(self, key, ()) - for bucket in buckets: - bucket.unlink(self) - return [x.value for x in buckets] - - def pop(self, key, default=_missing): - try: - buckets = dict.pop(self, key) - except KeyError: - if default is not _missing: - return default - - raise exceptions.BadRequestKeyError(key) from None - - for bucket in buckets: - bucket.unlink(self) - - return buckets[0].value - - def popitem(self): - try: - key, buckets = dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, buckets[0].value - - def popitemlist(self): - try: - key, buckets = dict.popitem(self) - except KeyError as e: - raise exceptions.BadRequestKeyError(e.args[0]) from None - - for bucket in buckets: - bucket.unlink(self) - - return key, [x.value for x in buckets] - - -def _options_header_vkw(value, kw): - return http.dump_options_header( - value, {k.replace("_", "-"): v for k, v in kw.items()} - ) - - -def 
_unicodify_header_value(value): - if isinstance(value, bytes): - value = value.decode("latin-1") - if not isinstance(value, str): - value = str(value) - return value - - -class Headers: - """An object that stores some headers. It has a dict-like interface, - but is ordered, can store the same key multiple times, and iterating - yields ``(key, value)`` pairs instead of only keys. - - This data structure is useful if you want a nicer way to handle WSGI - headers which are stored as tuples in a list. - - From Werkzeug 0.3 onwards, the :exc:`KeyError` raised by this class is - also a subclass of the :class:`~exceptions.BadRequest` HTTP exception - and will render a page for a ``400 BAD REQUEST`` if caught in a - catch-all for HTTP exceptions. - - Headers is mostly compatible with the Python :class:`wsgiref.headers.Headers` - class, with the exception of `__getitem__`. :mod:`wsgiref` will return - `None` for ``headers['missing']``, whereas :class:`Headers` will raise - a :class:`KeyError`. - - To create a new ``Headers`` object, pass it a list, dict, or - other ``Headers`` object with default values. These values are - validated the same way values added later are. - - :param defaults: The list of default values for the :class:`Headers`. - - .. versionchanged:: 2.1.0 - Default values are validated the same as values added later. - - .. versionchanged:: 0.9 - This data structure now stores unicode values similar to how the - multi dicts do it. The main difference is that bytes can be set as - well which will automatically be latin1 decoded. - - .. versionchanged:: 0.9 - The :meth:`linked` function was removed without replacement as it - was an API that does not support the changes to the encoding model. 
- """ - - def __init__(self, defaults=None): - self._list = [] - if defaults is not None: - self.extend(defaults) - - def __getitem__(self, key, _get_mode=False): - if not _get_mode: - if isinstance(key, int): - return self._list[key] - elif isinstance(key, slice): - return self.__class__(self._list[key]) - if not isinstance(key, str): - raise exceptions.BadRequestKeyError(key) - ikey = key.lower() - for k, v in self._list: - if k.lower() == ikey: - return v - # micro optimization: if we are in get mode we will catch that - # exception one stack level down so we can raise a standard - # key error instead of our special one. - if _get_mode: - raise KeyError() - raise exceptions.BadRequestKeyError(key) - - def __eq__(self, other): - def lowered(item): - return (item[0].lower(),) + item[1:] - - return other.__class__ is self.__class__ and set( - map(lowered, other._list) - ) == set(map(lowered, self._list)) - - __hash__ = None - - def get(self, key, default=None, type=None, as_bytes=False): - """Return the default value if the requested data doesn't exist. - If `type` is provided and is a callable it should convert the value, - return it or raise a :exc:`ValueError` if that is not possible. In - this case the function will return the default as if the value was not - found: - - >>> d = Headers([('Content-Length', '42')]) - >>> d.get('Content-Length', type=int) - 42 - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param default: The default value to be returned if the key can't - be looked up. If not further specified `None` is - returned. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the default value is returned. - :param as_bytes: return bytes instead of strings. 
- """ - try: - rv = self.__getitem__(key, _get_mode=True) - except KeyError: - return default - if as_bytes: - rv = rv.encode("latin1") - if type is None: - return rv - try: - return type(rv) - except ValueError: - return default - - def getlist(self, key, type=None, as_bytes=False): - """Return the list of items for a given key. If that key is not in the - :class:`Headers`, the return value will be an empty list. Just like - :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will - be converted with the callable defined there. - - .. versionadded:: 0.9 - Added support for `as_bytes`. - - :param key: The key to be looked up. - :param type: A callable that is used to cast the value in the - :class:`Headers`. If a :exc:`ValueError` is raised - by this callable the value will be removed from the list. - :return: a :class:`list` of all the values for the key. - :param as_bytes: return bytes instead of strings. - """ - ikey = key.lower() - result = [] - for k, v in self: - if k.lower() == ikey: - if as_bytes: - v = v.encode("latin1") - if type is not None: - try: - v = type(v) - except ValueError: - continue - result.append(v) - return result - - def get_all(self, name): - """Return a list of all the values for the named field. - - This method is compatible with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.get_all` method. - """ - return self.getlist(name) - - def items(self, lower=False): - for key, value in self: - if lower: - key = key.lower() - yield key, value - - def keys(self, lower=False): - for key, _ in self.items(lower): - yield key - - def values(self): - for _, value in self.items(): - yield value - - def extend(self, *args, **kwargs): - """Extend headers in this object with items from another object - containing header items as well as keyword arguments. - - To replace existing keys instead of extending, use - :meth:`update` instead. 
- - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionchanged:: 1.0 - Support :class:`MultiDict`. Allow passing ``kwargs``. - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - for key, value in iter_multi_items(args[0]): - self.add(key, value) - - for key, value in iter_multi_items(kwargs): - self.add(key, value) - - def __delitem__(self, key, _index_operation=True): - if _index_operation and isinstance(key, (int, slice)): - del self._list[key] - return - key = key.lower() - new = [] - for k, v in self._list: - if k.lower() != key: - new.append((k, v)) - self._list[:] = new - - def remove(self, key): - """Remove a key. - - :param key: The key to be removed. - """ - return self.__delitem__(key, _index_operation=False) - - def pop(self, key=None, default=_missing): - """Removes and returns a key or index. - - :param key: The key to be popped. If this is an integer the item at - that position is removed, if it's a string the value for - that key is. If the key is omitted or `None` the last - item is removed. - :return: an item. - """ - if key is None: - return self._list.pop() - if isinstance(key, int): - return self._list.pop(key) - try: - rv = self[key] - self.remove(key) - except KeyError: - if default is not _missing: - return default - raise - return rv - - def popitem(self): - """Removes a key or index and returns a (key, value) item.""" - return self.pop() - - def __contains__(self, key): - """Check if a key is present.""" - try: - self.__getitem__(key, _get_mode=True) - except KeyError: - return False - return True - - def __iter__(self): - """Yield ``(key, value)`` tuples.""" - return iter(self._list) - - def __len__(self): - return len(self._list) - - def add(self, _key, _value, **kw): - """Add a new header tuple to the list. 
- - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes:: - - >>> d = Headers() - >>> d.add('Content-Type', 'text/plain') - >>> d.add('Content-Disposition', 'attachment', filename='foo.png') - - The keyword argument dumping uses :func:`dump_options_header` - behind the scenes. - - .. versionadded:: 0.4.1 - keyword arguments were added for :mod:`wsgiref` compatibility. - """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - self._list.append((_key, _value)) - - def _validate_value(self, value): - if not isinstance(value, str): - raise TypeError("Value should be a string.") - if "\n" in value or "\r" in value: - raise ValueError( - "Detected newline in header value. This is " - "a potential security problem" - ) - - def add_header(self, _key, _value, **_kw): - """Add a new header tuple to the list. - - An alias for :meth:`add` for compatibility with the :mod:`wsgiref` - :meth:`~wsgiref.headers.Headers.add_header` method. - """ - self.add(_key, _value, **_kw) - - def clear(self): - """Clears all headers.""" - del self._list[:] - - def set(self, _key, _value, **kw): - """Remove all header tuples for `key` and add a new one. The newly - added key either appears at the end of the list if there was no - entry or replaces the first one. - - Keyword arguments can specify additional parameters for the header - value, with underscores converted to dashes. See :meth:`add` for - more information. - - .. versionchanged:: 0.6.1 - :meth:`set` now accepts the same arguments as :meth:`add`. - - :param key: The key to be inserted. - :param value: The value to be inserted. 
- """ - if kw: - _value = _options_header_vkw(_value, kw) - _key = _unicodify_header_value(_key) - _value = _unicodify_header_value(_value) - self._validate_value(_value) - if not self._list: - self._list.append((_key, _value)) - return - listiter = iter(self._list) - ikey = _key.lower() - for idx, (old_key, _old_value) in enumerate(listiter): - if old_key.lower() == ikey: - # replace first occurrence - self._list[idx] = (_key, _value) - break - else: - self._list.append((_key, _value)) - return - self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] - - def setlist(self, key, values): - """Remove any existing values for a header and add new ones. - - :param key: The header key to set. - :param values: An iterable of values to set for the key. - - .. versionadded:: 1.0 - """ - if values: - values_iter = iter(values) - self.set(key, next(values_iter)) - - for value in values_iter: - self.add(key, value) - else: - self.remove(key) - - def setdefault(self, key, default): - """Return the first value for the key if it is in the headers, - otherwise set the header to the value given by ``default`` and - return that. - - :param key: The header key to get. - :param default: The value to set for the key if it is not in the - headers. - """ - if key in self: - return self[key] - - self.set(key, default) - return default - - def setlistdefault(self, key, default): - """Return the list of values for the key if it is in the - headers, otherwise set the header to the list of values given - by ``default`` and return that. - - Unlike :meth:`MultiDict.setlistdefault`, modifying the returned - list will not affect the headers. - - :param key: The header key to get. - :param default: An iterable of values to set for the key if it - is not in the headers. - - .. 
versionadded:: 1.0 - """ - if key not in self: - self.setlist(key, default) - - return self.getlist(key) - - def __setitem__(self, key, value): - """Like :meth:`set` but also supports index/slice based setting.""" - if isinstance(key, (slice, int)): - if isinstance(key, int): - value = [value] - value = [ - (_unicodify_header_value(k), _unicodify_header_value(v)) - for (k, v) in value - ] - for (_, v) in value: - self._validate_value(v) - if isinstance(key, int): - self._list[key] = value[0] - else: - self._list[key] = value - else: - self.set(key, value) - - def update(self, *args, **kwargs): - """Replace headers in this object with items from another - headers object and keyword arguments. - - To extend existing keys instead of replacing, use :meth:`extend` - instead. - - If provided, the first argument can be another :class:`Headers` - object, a :class:`MultiDict`, :class:`dict`, or iterable of - pairs. - - .. versionadded:: 1.0 - """ - if len(args) > 1: - raise TypeError(f"update expected at most 1 arguments, got {len(args)}") - - if args: - mapping = args[0] - - if isinstance(mapping, (Headers, MultiDict)): - for key in mapping.keys(): - self.setlist(key, mapping.getlist(key)) - elif isinstance(mapping, dict): - for key, value in mapping.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - else: - for key, value in mapping: - self.set(key, value) - - for key, value in kwargs.items(): - if isinstance(value, (list, tuple)): - self.setlist(key, value) - else: - self.set(key, value) - - def to_wsgi_list(self): - """Convert the headers into a list suitable for WSGI. 
- - :return: list - """ - return list(self) - - def copy(self): - return self.__class__(self._list) - - def __copy__(self): - return self.copy() - - def __str__(self): - """Returns formatted headers suitable for HTTP transmission.""" - strs = [] - for key, value in self.to_wsgi_list(): - strs.append(f"{key}: {value}") - strs.append("\r\n") - return "\r\n".join(strs) - - def __repr__(self): - return f"{type(self).__name__}({list(self)!r})" - - -class ImmutableHeadersMixin: - """Makes a :class:`Headers` immutable. We do not mark them as - hashable though since the only usecase for this datastructure - in Werkzeug is a view on a mutable structure. - - .. versionadded:: 0.5 - - :private: - """ - - def __delitem__(self, key, **kwargs): - is_immutable(self) - - def __setitem__(self, key, value): - is_immutable(self) - - def set(self, _key, _value, **kw): - is_immutable(self) - - def setlist(self, key, values): - is_immutable(self) - - def add(self, _key, _value, **kw): - is_immutable(self) - - def add_header(self, _key, _value, **_kw): - is_immutable(self) - - def remove(self, key): - is_immutable(self) - - def extend(self, *args, **kwargs): - is_immutable(self) - - def update(self, *args, **kwargs): - is_immutable(self) - - def insert(self, pos, value): - is_immutable(self) - - def pop(self, key=None, default=_missing): - is_immutable(self) - - def popitem(self): - is_immutable(self) - - def setdefault(self, key, default): - is_immutable(self) - - def setlistdefault(self, key, default): - is_immutable(self) - - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - """Read only version of the headers from a WSGI environment. This - provides the same interface as `Headers` and is constructed from - a WSGI environment. - - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for - HTTP exceptions. 
- """ - - def __init__(self, environ): - self.environ = environ - - def __eq__(self, other): - return self.environ is other.environ - - __hash__ = None - - def __getitem__(self, key, _get_mode=False): - # _get_mode is a no-op for this class as there is no index but - # used because get() calls it. - if not isinstance(key, str): - raise KeyError(key) - key = key.upper().replace("-", "_") - if key in ("CONTENT_TYPE", "CONTENT_LENGTH"): - return _unicodify_header_value(self.environ[key]) - return _unicodify_header_value(self.environ[f"HTTP_{key}"]) - - def __len__(self): - # the iter is necessary because otherwise list calls our - # len which would call list again and so forth. - return len(list(iter(self))) - - def __iter__(self): - for key, value in self.environ.items(): - if key.startswith("HTTP_") and key not in ( - "HTTP_CONTENT_TYPE", - "HTTP_CONTENT_LENGTH", - ): - yield ( - key[5:].replace("_", "-").title(), - _unicodify_header_value(value), - ) - elif key in ("CONTENT_TYPE", "CONTENT_LENGTH") and value: - yield (key.replace("_", "-").title(), _unicodify_header_value(value)) - - def copy(self): - raise TypeError(f"cannot create {type(self).__name__!r} copies") - - -class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): - """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` - instances as sequence and it will combine the return values of all wrapped - dicts: - - >>> from werkzeug.datastructures import CombinedMultiDict, MultiDict - >>> post = MultiDict([('foo', 'bar')]) - >>> get = MultiDict([('blub', 'blah')]) - >>> combined = CombinedMultiDict([get, post]) - >>> combined['foo'] - 'bar' - >>> combined['blub'] - 'blah' - - This works for all read operations and will raise a `TypeError` for - methods that usually change data which isn't possible. 
- - From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a - subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will - render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP - exceptions. - """ - - def __reduce_ex__(self, protocol): - return type(self), (self.dicts,) - - def __init__(self, dicts=None): - self.dicts = list(dicts) or [] - - @classmethod - def fromkeys(cls, keys, value=None): - raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys") - - def __getitem__(self, key): - for d in self.dicts: - if key in d: - return d[key] - raise exceptions.BadRequestKeyError(key) - - def get(self, key, default=None, type=None): - for d in self.dicts: - if key in d: - if type is not None: - try: - return type(d[key]) - except ValueError: - continue - return d[key] - return default - - def getlist(self, key, type=None): - rv = [] - for d in self.dicts: - rv.extend(d.getlist(key, type)) - return rv - - def _keys_impl(self): - """This function exists so __len__ can be implemented more efficiently, - saving one list creation from an iterator. - """ - rv = set() - rv.update(*self.dicts) - return rv - - def keys(self): - return self._keys_impl() - - def __iter__(self): - return iter(self.keys()) - - def items(self, multi=False): - found = set() - for d in self.dicts: - for key, value in d.items(multi): - if multi: - yield key, value - elif key not in found: - found.add(key) - yield key, value - - def values(self): - for _key, value in self.items(): - yield value - - def lists(self): - rv = {} - for d in self.dicts: - for key, values in d.lists(): - rv.setdefault(key, []).extend(values) - return list(rv.items()) - - def listvalues(self): - return (x[1] for x in self.lists()) - - def copy(self): - """Return a shallow mutable copy of this object. - - This returns a :class:`MultiDict` representing the data at the - time of copying. The copy will no longer reflect changes to the - wrapped dicts. - - .. 
versionchanged:: 0.15 - Return a mutable :class:`MultiDict`. - """ - return MultiDict(self) - - def to_dict(self, flat=True): - """Return the contents as regular dict. If `flat` is `True` the - returned dict will only have the first item present, if `flat` is - `False` all values will be returned as lists. - - :param flat: If set to `False` the dict returned will have lists - with all the values in it. Otherwise it will only - contain the first item for each key. - :return: a :class:`dict` - """ - if flat: - return dict(self.items()) - - return dict(self.lists()) - - def __len__(self): - return len(self._keys_impl()) - - def __contains__(self, key): - for d in self.dicts: - if key in d: - return True - return False - - def __repr__(self): - return f"{type(self).__name__}({self.dicts!r})" - - -class FileMultiDict(MultiDict): - """A special :class:`MultiDict` that has convenience methods to add - files to it. This is used for :class:`EnvironBuilder` and generally - useful for unittesting. - - .. versionadded:: 0.5 - """ - - def add_file(self, name, file, filename=None, content_type=None): - """Adds a new file to the dict. `file` can be a file name or - a :class:`file`-like or a :class:`FileStorage` object. - - :param name: the name of the field. - :param file: a filename or :class:`file`-like object - :param filename: an optional filename - :param content_type: an optional content type - """ - if isinstance(file, FileStorage): - value = file - else: - if isinstance(file, str): - if filename is None: - filename = file - file = open(file, "rb") - if filename and content_type is None: - content_type = ( - mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - value = FileStorage(file, filename, name, content_type) - - self.add(name, value) - - -class ImmutableDict(ImmutableDictMixin, dict): - """An immutable :class:`dict`. - - .. 
versionadded:: 0.5 - """ - - def __repr__(self): - return f"{type(self).__name__}({dict.__repr__(self)})" - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return dict(self) - - def __copy__(self): - return self - - -class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): - """An immutable :class:`MultiDict`. - - .. versionadded:: 0.5 - """ - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return MultiDict(self) - - def __copy__(self): - return self - - -class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): - """An immutable :class:`OrderedMultiDict`. - - .. versionadded:: 0.6 - """ - - def _iter_hashitems(self): - return enumerate(self.items(multi=True)) - - def copy(self): - """Return a shallow mutable copy of this object. Keep in mind that - the standard library's :func:`copy` function is a no-op for this class - like for any other python immutable type (eg: :class:`tuple`). - """ - return OrderedMultiDict(self) - - def __copy__(self): - return self - - -class Accept(ImmutableList): - """An :class:`Accept` object is just a list subclass for lists of - ``(value, quality)`` tuples. It is automatically sorted by specificity - and quality. - - All :class:`Accept` objects work similar to a list but provide extra - functionality for working with the data. 
Containment checks are - normalized to the rules of that header: - - >>> a = CharsetAccept([('ISO-8859-1', 1), ('utf-8', 0.7)]) - >>> a.best - 'ISO-8859-1' - >>> 'iso-8859-1' in a - True - >>> 'UTF8' in a - True - >>> 'utf7' in a - False - - To get the quality for an item you can use normal item lookup: - - >>> print a['utf-8'] - 0.7 - >>> a['utf7'] - 0 - - .. versionchanged:: 0.5 - :class:`Accept` objects are forced immutable now. - - .. versionchanged:: 1.0.0 - :class:`Accept` internal values are no longer ordered - alphabetically for equal quality tags. Instead the initial - order is preserved. - - """ - - def __init__(self, values=()): - if values is None: - list.__init__(self) - self.provided = False - elif isinstance(values, Accept): - self.provided = values.provided - list.__init__(self, values) - else: - self.provided = True - values = sorted( - values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True - ) - list.__init__(self, values) - - def _specificity(self, value): - """Returns a tuple describing the value's specificity.""" - return (value != "*",) - - def _value_matches(self, value, item): - """Check if a value matches a given accept item.""" - return item == "*" or item.lower() == value.lower() - - def __getitem__(self, key): - """Besides index lookup (getting item n) you can also pass it a string - to get the quality for the item. If the item is not in the list, the - returned quality is ``0``. - """ - if isinstance(key, str): - return self.quality(key) - return list.__getitem__(self, key) - - def quality(self, key): - """Returns the quality of the key. - - .. 
versionadded:: 0.6 - In previous versions you had to use the item-lookup syntax - (eg: ``obj[key]`` instead of ``obj.quality(key)``) - """ - for item, quality in self: - if self._value_matches(key, item): - return quality - return 0 - - def __contains__(self, value): - for item, _quality in self: - if self._value_matches(value, item): - return True - return False - - def __repr__(self): - pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) - return f"{type(self).__name__}([{pairs_str}])" - - def index(self, key): - """Get the position of an entry or raise :exc:`ValueError`. - - :param key: The key to be looked up. - - .. versionchanged:: 0.5 - This used to raise :exc:`IndexError`, which was inconsistent - with the list API. - """ - if isinstance(key, str): - for idx, (item, _quality) in enumerate(self): - if self._value_matches(key, item): - return idx - raise ValueError(key) - return list.index(self, key) - - def find(self, key): - """Get the position of an entry or return -1. - - :param key: The key to be looked up. - """ - try: - return self.index(key) - except ValueError: - return -1 - - def values(self): - """Iterate over all values.""" - for item in self: - yield item[0] - - def to_header(self): - """Convert the header set into an HTTP header string.""" - result = [] - for value, quality in self: - if quality != 1: - value = f"{value};q={quality}" - result.append(value) - return ",".join(result) - - def __str__(self): - return self.to_header() - - def _best_single_match(self, match): - for client_item, quality in self: - if self._value_matches(match, client_item): - # self is sorted by specificity descending, we can exit - return client_item, quality - return None - - def best_match(self, matches, default=None): - """Returns the best match from a list of possible matches based - on the specificity and quality of the client. If two items have the - same quality and specificity, the one is returned that comes first. 
- - :param matches: a list of matches to check for - :param default: the value that is returned if none match - """ - result = default - best_quality = -1 - best_specificity = (-1,) - for server_item in matches: - match = self._best_single_match(server_item) - if not match: - continue - client_item, quality = match - specificity = self._specificity(client_item) - if quality <= 0 or quality < best_quality: - continue - # better quality or same quality but more specific => better match - if quality > best_quality or specificity > best_specificity: - result = server_item - best_quality = quality - best_specificity = specificity - return result - - @property - def best(self): - """The best match as value.""" - if self: - return self[0][0] - - -_mime_split_re = re.compile(r"/|(?:\s*;\s*)") - - -def _normalize_mime(value): - return _mime_split_re.split(value.lower()) - - -class MIMEAccept(Accept): - """Like :class:`Accept` but with special methods and behavior for - mimetypes. - """ - - def _specificity(self, value): - return tuple(x != "*" for x in _mime_split_re.split(value)) - - def _value_matches(self, value, item): - # item comes from the client, can't match if it's invalid. - if "/" not in item: - return False - - # value comes from the application, tell the developer when it - # doesn't look valid. - if "/" not in value: - raise ValueError(f"invalid mimetype {value!r}") - - # Split the match value into type, subtype, and a sorted list of parameters. - normalized_value = _normalize_mime(value) - value_type, value_subtype = normalized_value[:2] - value_params = sorted(normalized_value[2:]) - - # "*/*" is the only valid value that can start with "*". - if value_type == "*" and value_subtype != "*": - raise ValueError(f"invalid mimetype {value!r}") - - # Split the accept item into type, subtype, and parameters. 
- normalized_item = _normalize_mime(item) - item_type, item_subtype = normalized_item[:2] - item_params = sorted(normalized_item[2:]) - - # "*/not-*" from the client is invalid, can't match. - if item_type == "*" and item_subtype != "*": - return False - - return ( - (item_type == "*" and item_subtype == "*") - or (value_type == "*" and value_subtype == "*") - ) or ( - item_type == value_type - and ( - item_subtype == "*" - or value_subtype == "*" - or (item_subtype == value_subtype and item_params == value_params) - ) - ) - - @property - def accept_html(self): - """True if this object accepts HTML.""" - return ( - "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml - ) - - @property - def accept_xhtml(self): - """True if this object accepts XHTML.""" - return "application/xhtml+xml" in self or "application/xml" in self - - @property - def accept_json(self): - """True if this object accepts JSON.""" - return "application/json" in self - - -_locale_delim_re = re.compile(r"[_-]") - - -def _normalize_lang(value): - """Process a language tag for matching.""" - return _locale_delim_re.split(value.lower()) - - -class LanguageAccept(Accept): - """Like :class:`Accept` but with normalization for language tags.""" - - def _value_matches(self, value, item): - return item == "*" or _normalize_lang(value) == _normalize_lang(item) - - def best_match(self, matches, default=None): - """Given a list of supported values, finds the best match from - the list of accepted values. - - Language tags are normalized for the purpose of matching, but - are returned unchanged. - - If no exact match is found, this will fall back to matching - the first subtag (primary language only), first with the - accepted values then with the match values. This partial is not - applied to any other language subtags. - - The default is returned if no exact or fallback match is found. - - :param matches: A list of supported languages to find a match. 
- :param default: The value that is returned if none match. - """ - # Look for an exact match first. If a client accepts "en-US", - # "en-US" is a valid match at this point. - result = super().best_match(matches) - - if result is not None: - return result - - # Fall back to accepting primary tags. If a client accepts - # "en-US", "en" is a valid match at this point. Need to use - # re.split to account for 2 or 3 letter codes. - fallback = Accept( - [(_locale_delim_re.split(item[0], 1)[0], item[1]) for item in self] - ) - result = fallback.best_match(matches) - - if result is not None: - return result - - # Fall back to matching primary tags. If the client accepts - # "en", "en-US" is a valid match at this point. - fallback_matches = [_locale_delim_re.split(item, 1)[0] for item in matches] - result = super().best_match(fallback_matches) - - # Return a value from the original match list. Find the first - # original value that starts with the matched primary tag. - if result is not None: - return next(item for item in matches if item.startswith(result)) - - return default - - -class CharsetAccept(Accept): - """Like :class:`Accept` but with normalization for charsets.""" - - def _value_matches(self, value, item): - def _normalize(name): - try: - return codecs.lookup(name).name - except LookupError: - return name.lower() - - return item == "*" or _normalize(value) == _normalize(item) - - -def cache_control_property(key, empty, type): - """Return a new property object for a cache header. Useful if you - want to add support for a cache extension in a subclass. - - .. versionchanged:: 2.0 - Renamed from ``cache_property``. - """ - return property( - lambda x: x._get_cache_value(key, empty, type), - lambda x, v: x._set_cache_value(key, v, type), - lambda x: x._del_cache_value(key), - f"accessor for {key!r}", - ) - - -class _CacheControl(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Cache-Control header. 
It - has accessors for all the cache-control directives specified in RFC 2616. - The class does not differentiate between request and response directives. - - Because the cache-control directives in the HTTP header use dashes the - python descriptors use underscores for that. - - To get a header of the :class:`CacheControl` object again you can convert - the object into a string or call the :meth:`to_header` method. If you plan - to subclass it and add your own items have a look at the sourcecode for - that class. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionchanged:: 0.4 - - Setting `no_cache` or `private` to boolean `True` will set the implicit - none-value which is ``*``: - - >>> cc = ResponseCacheControl() - >>> cc.no_cache = True - >>> cc - - >>> cc.no_cache - '*' - >>> cc.no_cache = None - >>> cc - - - In versions before 0.5 the behavior documented here affected the now - no longer existing `CacheControl` class. 
- """ - - no_cache = cache_control_property("no-cache", "*", None) - no_store = cache_control_property("no-store", None, bool) - max_age = cache_control_property("max-age", -1, int) - no_transform = cache_control_property("no-transform", None, None) - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = on_update - self.provided = values is not None - - def _get_cache_value(self, key, empty, type): - """Used internally by the accessor properties.""" - if type is bool: - return key in self - if key in self: - value = self[key] - if value is None: - return empty - elif type is not None: - try: - value = type(value) - except ValueError: - pass - return value - return None - - def _set_cache_value(self, key, value, type): - """Used internally by the accessor properties.""" - if type is bool: - if value: - self[key] = None - else: - self.pop(key, None) - else: - if value is None: - self.pop(key, None) - elif value is True: - self[key] = None - else: - if type is not None: - self[key] = type(value) - else: - self[key] = value - - def _del_cache_value(self, key): - """Used internally by the accessor properties.""" - if key in self: - del self[key] - - def to_header(self): - """Convert the stored values into a cache control header.""" - return http.dump_header(self) - - def __str__(self): - return self.to_header() - - def __repr__(self): - kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) - return f"<{type(self).__name__} {kv_str}>" - - cache_property = staticmethod(cache_control_property) - - -class RequestCacheControl(ImmutableDictMixin, _CacheControl): - """A cache control for requests. This is immutable and gives access - to all the request-relevant cache control headers. - - To get a header of the :class:`RequestCacheControl` object again you can - convert the object into a string or call the :meth:`to_header` method. 
If - you plan to subclass it and add your own items have a look at the sourcecode - for that class. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. - """ - - max_stale = cache_control_property("max-stale", "*", int) - min_fresh = cache_control_property("min-fresh", "*", int) - only_if_cached = cache_control_property("only-if-cached", None, bool) - - -class ResponseCacheControl(_CacheControl): - """A cache control for responses. Unlike :class:`RequestCacheControl` - this is mutable and gives access to response-relevant cache control - headers. - - To get a header of the :class:`ResponseCacheControl` object again you can - convert the object into a string or call the :meth:`to_header` method. If - you plan to subclass it and add your own items have a look at the sourcecode - for that class. - - .. versionchanged:: 2.1.1 - ``s_maxage`` converts the value to an int. - - .. versionchanged:: 2.1.0 - Setting int properties such as ``max_age`` will convert the - value to an int. - - .. versionadded:: 0.5 - In previous versions a `CacheControl` class existed that was used - both for request and response. - """ - - public = cache_control_property("public", None, bool) - private = cache_control_property("private", "*", None) - must_revalidate = cache_control_property("must-revalidate", None, bool) - proxy_revalidate = cache_control_property("proxy-revalidate", None, bool) - s_maxage = cache_control_property("s-maxage", None, int) - immutable = cache_control_property("immutable", None, bool) - - -def csp_property(key): - """Return a new property object for a content security policy header. - Useful if you want to add support for a csp extension in a - subclass. 
- """ - return property( - lambda x: x._get_value(key), - lambda x, v: x._set_value(key, v), - lambda x: x._del_value(key), - f"accessor for {key!r}", - ) - - -class ContentSecurityPolicy(UpdateDictMixin, dict): - """Subclass of a dict that stores values for a Content Security Policy - header. It has accessors for all the level 3 policies. - - Because the csp directives in the HTTP header use dashes the - python descriptors use underscores for that. - - To get a header of the :class:`ContentSecuirtyPolicy` object again - you can convert the object into a string or call the - :meth:`to_header` method. If you plan to subclass it and add your - own items have a look at the sourcecode for that class. - - .. versionadded:: 1.0.0 - Support for Content Security Policy headers was added. - - """ - - base_uri = csp_property("base-uri") - child_src = csp_property("child-src") - connect_src = csp_property("connect-src") - default_src = csp_property("default-src") - font_src = csp_property("font-src") - form_action = csp_property("form-action") - frame_ancestors = csp_property("frame-ancestors") - frame_src = csp_property("frame-src") - img_src = csp_property("img-src") - manifest_src = csp_property("manifest-src") - media_src = csp_property("media-src") - navigate_to = csp_property("navigate-to") - object_src = csp_property("object-src") - prefetch_src = csp_property("prefetch-src") - plugin_types = csp_property("plugin-types") - report_to = csp_property("report-to") - report_uri = csp_property("report-uri") - sandbox = csp_property("sandbox") - script_src = csp_property("script-src") - script_src_attr = csp_property("script-src-attr") - script_src_elem = csp_property("script-src-elem") - style_src = csp_property("style-src") - style_src_attr = csp_property("style-src-attr") - style_src_elem = csp_property("style-src-elem") - worker_src = csp_property("worker-src") - - def __init__(self, values=(), on_update=None): - dict.__init__(self, values or ()) - self.on_update = 
on_update - self.provided = values is not None - - def _get_value(self, key): - """Used internally by the accessor properties.""" - return self.get(key) - - def _set_value(self, key, value): - """Used internally by the accessor properties.""" - if value is None: - self.pop(key, None) - else: - self[key] = value - - def _del_value(self, key): - """Used internally by the accessor properties.""" - if key in self: - del self[key] - - def to_header(self): - """Convert the stored values into a cache control header.""" - return http.dump_csp_header(self) - - def __str__(self): - return self.to_header() - - def __repr__(self): - kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) - return f"<{type(self).__name__} {kv_str}>" - - -class CallbackDict(UpdateDictMixin, dict): - """A dict that calls a function passed every time something is changed. - The function is passed the dict instance. - """ - - def __init__(self, initial=None, on_update=None): - dict.__init__(self, initial or ()) - self.on_update = on_update - - def __repr__(self): - return f"<{type(self).__name__} {dict.__repr__(self)}>" - - -class HeaderSet(MutableSet): - """Similar to the :class:`ETags` class this implements a set-like structure. - Unlike :class:`ETags` this is case insensitive and used for vary, allow, and - content-language headers. - - If not constructed using the :func:`parse_set_header` function the - instantiation works like this: - - >>> hs = HeaderSet(['foo', 'bar', 'baz']) - >>> hs - HeaderSet(['foo', 'bar', 'baz']) - """ - - def __init__(self, headers=None, on_update=None): - self._headers = list(headers or ()) - self._set = {x.lower() for x in self._headers} - self.on_update = on_update - - def add(self, header): - """Add a new header to the set.""" - self.update((header,)) - - def remove(self, header): - """Remove a header from the set. This raises an :exc:`KeyError` if the - header is not in the set. - - .. 
versionchanged:: 0.5 - In older versions a :exc:`IndexError` was raised instead of a - :exc:`KeyError` if the object was missing. - - :param header: the header to be removed. - """ - key = header.lower() - if key not in self._set: - raise KeyError(header) - self._set.remove(key) - for idx, key in enumerate(self._headers): - if key.lower() == header: - del self._headers[idx] - break - if self.on_update is not None: - self.on_update(self) - - def update(self, iterable): - """Add all the headers from the iterable to the set. - - :param iterable: updates the set with the items from the iterable. - """ - inserted_any = False - for header in iterable: - key = header.lower() - if key not in self._set: - self._headers.append(header) - self._set.add(key) - inserted_any = True - if inserted_any and self.on_update is not None: - self.on_update(self) - - def discard(self, header): - """Like :meth:`remove` but ignores errors. - - :param header: the header to be discarded. - """ - try: - self.remove(header) - except KeyError: - pass - - def find(self, header): - """Return the index of the header in the set or return -1 if not found. - - :param header: the header to be looked up. - """ - header = header.lower() - for idx, item in enumerate(self._headers): - if item.lower() == header: - return idx - return -1 - - def index(self, header): - """Return the index of the header in the set or raise an - :exc:`IndexError`. - - :param header: the header to be looked up. - """ - rv = self.find(header) - if rv < 0: - raise IndexError(header) - return rv - - def clear(self): - """Clear the set.""" - self._set.clear() - del self._headers[:] - if self.on_update is not None: - self.on_update(self) - - def as_set(self, preserve_casing=False): - """Return the set as real python set type. When calling this, all - the items are converted to lowercase and the ordering is lost. 
- - :param preserve_casing: if set to `True` the items in the set returned - will have the original case like in the - :class:`HeaderSet`, otherwise they will - be lowercase. - """ - if preserve_casing: - return set(self._headers) - return set(self._set) - - def to_header(self): - """Convert the header set into an HTTP header string.""" - return ", ".join(map(http.quote_header_value, self._headers)) - - def __getitem__(self, idx): - return self._headers[idx] - - def __delitem__(self, idx): - rv = self._headers.pop(idx) - self._set.remove(rv.lower()) - if self.on_update is not None: - self.on_update(self) - - def __setitem__(self, idx, value): - old = self._headers[idx] - self._set.remove(old.lower()) - self._headers[idx] = value - self._set.add(value.lower()) - if self.on_update is not None: - self.on_update(self) - - def __contains__(self, header): - return header.lower() in self._set - - def __len__(self): - return len(self._set) - - def __iter__(self): - return iter(self._headers) - - def __bool__(self): - return bool(self._set) - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"{type(self).__name__}({self._headers!r})" - - -class ETags(Collection): - """A set that can be used to check if one etag is present in a collection - of etags. - """ - - def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): - if not star_tag and strong_etags: - self._strong = frozenset(strong_etags) - else: - self._strong = frozenset() - - self._weak = frozenset(weak_etags or ()) - self.star_tag = star_tag - - def as_set(self, include_weak=False): - """Convert the `ETags` object into a python set. 
Per default all the - weak etags are not part of this set.""" - rv = set(self._strong) - if include_weak: - rv.update(self._weak) - return rv - - def is_weak(self, etag): - """Check if an etag is weak.""" - return etag in self._weak - - def is_strong(self, etag): - """Check if an etag is strong.""" - return etag in self._strong - - def contains_weak(self, etag): - """Check if an etag is part of the set including weak and strong tags.""" - return self.is_weak(etag) or self.contains(etag) - - def contains(self, etag): - """Check if an etag is part of the set ignoring weak tags. - It is also possible to use the ``in`` operator. - """ - if self.star_tag: - return True - return self.is_strong(etag) - - def contains_raw(self, etag): - """When passed a quoted tag it will check if this tag is part of the - set. If the tag is weak it is checked against weak and strong tags, - otherwise strong only.""" - etag, weak = http.unquote_etag(etag) - if weak: - return self.contains_weak(etag) - return self.contains(etag) - - def to_header(self): - """Convert the etags set into a HTTP header string.""" - if self.star_tag: - return "*" - return ", ".join( - [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak] - ) - - def __call__(self, etag=None, data=None, include_weak=False): - if [etag, data].count(None) != 1: - raise TypeError("either tag or data required, but at least one") - if etag is None: - etag = http.generate_etag(data) - if include_weak: - if etag in self._weak: - return True - return etag in self._strong - - def __bool__(self): - return bool(self.star_tag or self._strong or self._weak) - - def __str__(self): - return self.to_header() - - def __len__(self): - return len(self._strong) - - def __iter__(self): - return iter(self._strong) - - def __contains__(self, etag): - return self.contains(etag) - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class IfRange: - """Very simple object that represents the `If-Range` header in 
parsed - form. It will either have neither a etag or date or one of either but - never both. - - .. versionadded:: 0.7 - """ - - def __init__(self, etag=None, date=None): - #: The etag parsed and unquoted. Ranges always operate on strong - #: etags so the weakness information is not necessary. - self.etag = etag - #: The date in parsed format or `None`. - self.date = date - - def to_header(self): - """Converts the object back into an HTTP header.""" - if self.date is not None: - return http.http_date(self.date) - if self.etag is not None: - return http.quote_etag(self.etag) - return "" - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class Range: - """Represents a ``Range`` header. All methods only support only - bytes as the unit. Stores a list of ranges if given, but the methods - only work if only one range is provided. - - :raise ValueError: If the ranges provided are invalid. - - .. versionchanged:: 0.15 - The ranges passed in are validated. - - .. versionadded:: 0.7 - """ - - def __init__(self, units, ranges): - #: The units of this range. Usually "bytes". - self.units = units - #: A list of ``(begin, end)`` tuples for the range header provided. - #: The ranges are non-inclusive. - self.ranges = ranges - - for start, end in ranges: - if start is None or (end is not None and (start < 0 or start >= end)): - raise ValueError(f"{(start, end)} is not a valid range.") - - def range_for_length(self, length): - """If the range is for bytes, the length is not None and there is - exactly one range and it is satisfiable it returns a ``(start, stop)`` - tuple, otherwise `None`. 
- """ - if self.units != "bytes" or length is None or len(self.ranges) != 1: - return None - start, end = self.ranges[0] - if end is None: - end = length - if start < 0: - start += length - if http.is_byte_range_valid(start, end, length): - return start, min(end, length) - return None - - def make_content_range(self, length): - """Creates a :class:`~werkzeug.datastructures.ContentRange` object - from the current range and given content length. - """ - rng = self.range_for_length(length) - if rng is not None: - return ContentRange(self.units, rng[0], rng[1], length) - return None - - def to_header(self): - """Converts the object back into an HTTP header.""" - ranges = [] - for begin, end in self.ranges: - if end is None: - ranges.append(f"{begin}-" if begin >= 0 else str(begin)) - else: - ranges.append(f"{begin}-{end - 1}") - return f"{self.units}={','.join(ranges)}" - - def to_content_range_header(self, length): - """Converts the object into `Content-Range` HTTP header, - based on given length - """ - range = self.range_for_length(length) - if range is not None: - return f"{self.units} {range[0]}-{range[1] - 1}/{length}" - return None - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -def _callback_property(name): - def fget(self): - return getattr(self, name) - - def fset(self, value): - setattr(self, name, value) - if self.on_update is not None: - self.on_update(self) - - return property(fget, fset) - - -class ContentRange: - """Represents the content range header. - - .. versionadded:: 0.7 - """ - - def __init__(self, units, start, stop, length=None, on_update=None): - assert http.is_byte_range_valid(start, stop, length), "Bad range provided" - self.on_update = on_update - self.set(start, stop, length, units) - - #: The units to use, usually "bytes" - units = _callback_property("_units") - #: The start point of the range or `None`. 
- start = _callback_property("_start") - #: The stop point of the range (non-inclusive) or `None`. Can only be - #: `None` if also start is `None`. - stop = _callback_property("_stop") - #: The length of the range or `None`. - length = _callback_property("_length") - - def set(self, start, stop, length=None, units="bytes"): - """Simple method to update the ranges.""" - assert http.is_byte_range_valid(start, stop, length), "Bad range provided" - self._units = units - self._start = start - self._stop = stop - self._length = length - if self.on_update is not None: - self.on_update(self) - - def unset(self): - """Sets the units to `None` which indicates that the header should - no longer be used. - """ - self.set(None, None, units=None) - - def to_header(self): - if self.units is None: - return "" - if self.length is None: - length = "*" - else: - length = self.length - if self.start is None: - return f"{self.units} */{length}" - return f"{self.units} {self.start}-{self.stop - 1}/{length}" - - def __bool__(self): - return self.units is not None - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {str(self)!r}>" - - -class Authorization(ImmutableDictMixin, dict): - """Represents an ``Authorization`` header sent by the client. - - This is returned by - :func:`~werkzeug.http.parse_authorization_header`. It can be useful - to create the object manually to pass to the test - :class:`~werkzeug.test.Client`. - - .. versionchanged:: 0.5 - This object became immutable. - """ - - def __init__(self, auth_type, data=None): - dict.__init__(self, data or {}) - self.type = auth_type - - @property - def username(self): - """The username transmitted. This is set for both basic and digest - auth all the time. - """ - return self.get("username") - - @property - def password(self): - """When the authentication type is basic this is the password - transmitted by the client, else `None`. 
- """ - return self.get("password") - - @property - def realm(self): - """This is the server realm sent back for HTTP digest auth.""" - return self.get("realm") - - @property - def nonce(self): - """The nonce the server sent for digest auth, sent back by the client. - A nonce should be unique for every 401 response for HTTP digest auth. - """ - return self.get("nonce") - - @property - def uri(self): - """The URI from Request-URI of the Request-Line; duplicated because - proxies are allowed to change the Request-Line in transit. HTTP - digest auth only. - """ - return self.get("uri") - - @property - def nc(self): - """The nonce count value transmitted by clients if a qop-header is - also transmitted. HTTP digest auth only. - """ - return self.get("nc") - - @property - def cnonce(self): - """If the server sent a qop-header in the ``WWW-Authenticate`` - header, the client has to provide this value for HTTP digest auth. - See the RFC for more details. - """ - return self.get("cnonce") - - @property - def response(self): - """A string of 32 hex digits computed as defined in RFC 2617, which - proves that the user knows a password. Digest auth only. - """ - return self.get("response") - - @property - def opaque(self): - """The opaque header from the server returned unchanged by the client. - It is recommended that this string be base64 or hexadecimal data. - Digest auth only. - """ - return self.get("opaque") - - @property - def qop(self): - """Indicates what "quality of protection" the client has applied to - the message for HTTP digest auth. Note that this is a single token, - not a quoted list of alternatives as in WWW-Authenticate. - """ - return self.get("qop") - - def to_header(self): - """Convert to a string value for an ``Authorization`` header. - - .. versionadded:: 2.0 - Added to support passing authorization to the test client. 
- """ - if self.type == "basic": - value = base64.b64encode( - f"{self.username}:{self.password}".encode() - ).decode("utf8") - return f"Basic {value}" - - if self.type == "digest": - return f"Digest {http.dump_header(self)}" - - raise ValueError(f"Unsupported type {self.type!r}.") - - -def auth_property(name, doc=None): - """A static helper function for Authentication subclasses to add - extra authentication system properties onto a class:: - - class FooAuthenticate(WWWAuthenticate): - special_realm = auth_property('special_realm') - - For more information have a look at the sourcecode to see how the - regular properties (:attr:`realm` etc.) are implemented. - """ - - def _set_value(self, value): - if value is None: - self.pop(name, None) - else: - self[name] = str(value) - - return property(lambda x: x.get(name), _set_value, doc=doc) - - -def _set_property(name, doc=None): - def fget(self): - def on_update(header_set): - if not header_set and name in self: - del self[name] - elif header_set: - self[name] = header_set.to_header() - - return http.parse_set_header(self.get(name), on_update) - - return property(fget, doc=doc) - - -class WWWAuthenticate(UpdateDictMixin, dict): - """Provides simple access to `WWW-Authenticate` headers.""" - - #: list of keys that require quoting in the generated header - _require_quoting = frozenset(["domain", "nonce", "opaque", "realm", "qop"]) - - def __init__(self, auth_type=None, values=None, on_update=None): - dict.__init__(self, values or ()) - if auth_type: - self["__auth_type__"] = auth_type - self.on_update = on_update - - def set_basic(self, realm="authentication required"): - """Clear the auth info and enable basic auth.""" - dict.clear(self) - dict.update(self, {"__auth_type__": "basic", "realm": realm}) - if self.on_update: - self.on_update(self) - - def set_digest( - self, realm, nonce, qop=("auth",), opaque=None, algorithm=None, stale=False - ): - """Clear the auth info and enable digest auth.""" - d = { - 
"__auth_type__": "digest", - "realm": realm, - "nonce": nonce, - "qop": http.dump_header(qop), - } - if stale: - d["stale"] = "TRUE" - if opaque is not None: - d["opaque"] = opaque - if algorithm is not None: - d["algorithm"] = algorithm - dict.clear(self) - dict.update(self, d) - if self.on_update: - self.on_update(self) - - def to_header(self): - """Convert the stored values into a WWW-Authenticate header.""" - d = dict(self) - auth_type = d.pop("__auth_type__", None) or "basic" - kv_items = ( - (k, http.quote_header_value(v, allow_token=k not in self._require_quoting)) - for k, v in d.items() - ) - kv_string = ", ".join([f"{k}={v}" for k, v in kv_items]) - return f"{auth_type.title()} {kv_string}" - - def __str__(self): - return self.to_header() - - def __repr__(self): - return f"<{type(self).__name__} {self.to_header()!r}>" - - type = auth_property( - "__auth_type__", - doc="""The type of the auth mechanism. HTTP currently specifies - ``Basic`` and ``Digest``.""", - ) - realm = auth_property( - "realm", - doc="""A string to be displayed to users so they know which - username and password to use. This string should contain at - least the name of the host performing the authentication and - might additionally indicate the collection of users who might - have access.""", - ) - domain = _set_property( - "domain", - doc="""A list of URIs that define the protection space. If a URI - is an absolute path, it is relative to the canonical root URL of - the server being accessed.""", - ) - nonce = auth_property( - "nonce", - doc=""" - A server-specified data string which should be uniquely generated - each time a 401 response is made. It is recommended that this - string be base64 or hexadecimal data.""", - ) - opaque = auth_property( - "opaque", - doc="""A string of data, specified by the server, which should - be returned by the client unchanged in the Authorization header - of subsequent requests with URIs in the same protection space. 
- It is recommended that this string be base64 or hexadecimal - data.""", - ) - algorithm = auth_property( - "algorithm", - doc="""A string indicating a pair of algorithms used to produce - the digest and a checksum. If this is not present it is assumed - to be "MD5". If the algorithm is not understood, the challenge - should be ignored (and a different one used, if there is more - than one).""", - ) - qop = _set_property( - "qop", - doc="""A set of quality-of-privacy directives such as auth and - auth-int.""", - ) - - @property - def stale(self): - """A flag, indicating that the previous request from the client - was rejected because the nonce value was stale. - """ - val = self.get("stale") - if val is not None: - return val.lower() == "true" - - @stale.setter - def stale(self, value): - if value is None: - self.pop("stale", None) - else: - self["stale"] = "TRUE" if value else "FALSE" - - auth_property = staticmethod(auth_property) - - -class FileStorage: - """The :class:`FileStorage` class is a thin wrapper over incoming files. - It is used by the request object to represent uploaded files. All the - attributes of the wrapper stream are proxied by the file storage so - it's possible to do ``storage.read()`` instead of the long form - ``storage.stream.read()``. - """ - - def __init__( - self, - stream=None, - filename=None, - name=None, - content_type=None, - content_length=None, - headers=None, - ): - self.name = name - self.stream = stream or BytesIO() - - # If no filename is provided, attempt to get the filename from - # the stream object. Python names special streams like - # ```` with angular brackets, skip these streams. 
- if filename is None: - filename = getattr(stream, "name", None) - - if filename is not None: - filename = os.fsdecode(filename) - - if filename and filename[0] == "<" and filename[-1] == ">": - filename = None - else: - filename = os.fsdecode(filename) - - self.filename = filename - - if headers is None: - headers = Headers() - self.headers = headers - if content_type is not None: - headers["Content-Type"] = content_type - if content_length is not None: - headers["Content-Length"] = str(content_length) - - def _parse_content_type(self): - if not hasattr(self, "_parsed_content_type"): - self._parsed_content_type = http.parse_options_header(self.content_type) - - @property - def content_type(self): - """The content-type sent in the header. Usually not available""" - return self.headers.get("content-type") - - @property - def content_length(self): - """The content-length sent in the header. Usually not available""" - try: - return int(self.headers.get("content-length") or 0) - except ValueError: - return 0 - - @property - def mimetype(self): - """Like :attr:`content_type`, but without parameters (eg, without - charset, type etc.) and always lowercase. For example if the content - type is ``text/HTML; charset=utf-8`` the mimetype would be - ``'text/html'``. - - .. versionadded:: 0.7 - """ - self._parse_content_type() - return self._parsed_content_type[0].lower() - - @property - def mimetype_params(self): - """The mimetype parameters as dict. For example if the content - type is ``text/html; charset=utf-8`` the params would be - ``{'charset': 'utf-8'}``. - - .. versionadded:: 0.7 - """ - self._parse_content_type() - return self._parsed_content_type[1] - - def save(self, dst, buffer_size=16384): - """Save the file to a destination path or file object. If the - destination is a file object you have to close it yourself after the - call. The buffer size is the number of bytes held in memory during - the copy process. It defaults to 16KB. 
- - For secure file saving also have a look at :func:`secure_filename`. - - :param dst: a filename, :class:`os.PathLike`, or open file - object to write to. - :param buffer_size: Passed as the ``length`` parameter of - :func:`shutil.copyfileobj`. - - .. versionchanged:: 1.0 - Supports :mod:`pathlib`. - """ - from shutil import copyfileobj - - close_dst = False - - if hasattr(dst, "__fspath__"): - dst = fspath(dst) - - if isinstance(dst, str): - dst = open(dst, "wb") - close_dst = True - - try: - copyfileobj(self.stream, dst, buffer_size) - finally: - if close_dst: - dst.close() - - def close(self): - """Close the underlying file if possible.""" - try: - self.stream.close() - except Exception: - pass - - def __bool__(self): - return bool(self.filename) - - def __getattr__(self, name): - try: - return getattr(self.stream, name) - except AttributeError: - # SpooledTemporaryFile doesn't implement IOBase, get the - # attribute from its backing file instead. - # https://github.com/python/cpython/pull/3249 - if hasattr(self.stream, "_file"): - return getattr(self.stream._file, name) - raise - - def __iter__(self): - return iter(self.stream) - - def __repr__(self): - return f"<{type(self).__name__}: {self.filename!r} ({self.content_type!r})>" - - -# circular dependencies -from . 
import http diff --git a/src/werkzeug/datastructures.pyi b/src/werkzeug/datastructures.pyi deleted file mode 100644 index 7bf7297..0000000 --- a/src/werkzeug/datastructures.pyi +++ /dev/null @@ -1,921 +0,0 @@ -from datetime import datetime -from os import PathLike -from typing import Any -from typing import Callable -from typing import Collection -from typing import Dict -from typing import FrozenSet -from typing import Generic -from typing import Hashable -from typing import IO -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import NoReturn -from typing import Optional -from typing import overload -from typing import Set -from typing import Tuple -from typing import Type -from typing import TypeVar -from typing import Union -from _typeshed import SupportsKeysAndGetItem -from _typeshed.wsgi import WSGIEnvironment - -from typing_extensions import Literal -from typing_extensions import SupportsIndex - -K = TypeVar("K") -V = TypeVar("V") -T = TypeVar("T") -D = TypeVar("D") -_CD = TypeVar("_CD", bound="CallbackDict") - -def is_immutable(self: object) -> NoReturn: ... -def iter_multi_items( - mapping: Union[Mapping[K, Union[V, Iterable[V]]], Iterable[Tuple[K, V]]] -) -> Iterator[Tuple[K, V]]: ... - -class ImmutableListMixin(List[V]): - _hash_cache: Optional[int] - def __hash__(self) -> int: ... # type: ignore - def __delitem__(self, key: Union[SupportsIndex, slice]) -> NoReturn: ... - def __iadd__(self, other: t.Any) -> NoReturn: ... # type: ignore - def __imul__(self, other: SupportsIndex) -> NoReturn: ... - def __setitem__( # type: ignore - self, key: Union[int, slice], value: V - ) -> NoReturn: ... - def append(self, value: V) -> NoReturn: ... - def remove(self, value: V) -> NoReturn: ... - def extend(self, values: Iterable[V]) -> NoReturn: ... - def insert(self, pos: SupportsIndex, value: V) -> NoReturn: ... - def pop(self, index: SupportsIndex = -1) -> NoReturn: ... 
- def reverse(self) -> NoReturn: ... - def sort( - self, key: Optional[Callable[[V], Any]] = None, reverse: bool = False - ) -> NoReturn: ... - -class ImmutableList(ImmutableListMixin[V]): ... - -class ImmutableDictMixin(Dict[K, V]): - _hash_cache: Optional[int] - @classmethod - def fromkeys( # type: ignore - cls, keys: Iterable[K], value: Optional[V] = None - ) -> ImmutableDictMixin[K, V]: ... - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def __hash__(self) -> int: ... # type: ignore - def setdefault(self, key: K, default: Optional[V] = None) -> NoReturn: ... - def update(self, *args: Any, **kwargs: V) -> NoReturn: ... - def pop(self, key: K, default: Optional[V] = None) -> NoReturn: ... # type: ignore - def popitem(self) -> NoReturn: ... - def __setitem__(self, key: K, value: V) -> NoReturn: ... - def __delitem__(self, key: K) -> NoReturn: ... - def clear(self) -> NoReturn: ... - -class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): - def _iter_hashitems(self) -> Iterable[Hashable]: ... - def add(self, key: K, value: V) -> NoReturn: ... - def popitemlist(self) -> NoReturn: ... - def poplist(self, key: K) -> NoReturn: ... - def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> NoReturn: ... - -def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ... - -class UpdateDictMixin(Dict[K, V]): - on_update: Optional[Callable[[UpdateDictMixin[K, V]], None]] - def setdefault(self, key: K, default: Optional[V] = None) -> V: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... - def clear(self) -> None: ... - def popitem(self) -> Tuple[K, V]: ... - @overload - def update(self, __m: SupportsKeysAndGetItem[K, V], **kwargs: V) -> None: ... 
- @overload - def update(self, __m: Iterable[Tuple[K, V]], **kwargs: V) -> None: ... - @overload - def update(self, **kwargs: V) -> None: ... - -class TypeConversionDict(Dict[K, V]): - @overload - def get(self, key: K, default: None = ..., type: None = ...) -> Optional[V]: ... - @overload - def get(self, key: K, default: D, type: None = ...) -> Union[D, V]: ... - @overload - def get(self, key: K, default: D, type: Callable[[V], T]) -> Union[D, T]: ... - @overload - def get(self, key: K, type: Callable[[V], T]) -> Optional[T]: ... - -class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): - def copy(self) -> TypeConversionDict[K, V]: ... - def __copy__(self) -> ImmutableTypeConversionDict: ... - -class MultiDict(TypeConversionDict[K, V]): - def __init__( - self, - mapping: Optional[ - Union[Mapping[K, Union[Iterable[V], V]], Iterable[Tuple[K, V]]] - ] = None, - ) -> None: ... - def __getitem__(self, item: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... - def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setdefault(self, key: K, default: Optional[V] = None) -> V: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> List[V]: ... - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def values(self) -> Iterator[V]: ... # type: ignore - def listvalues(self) -> Iterator[List[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> Dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... 
- def update( # type: ignore - self, mapping: Union[Mapping[K, Union[Iterable[V], V]], Iterable[Tuple[K, V]]] - ) -> None: ... - @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def popitem(self) -> Tuple[K, V]: ... - def poplist(self, key: K) -> List[V]: ... - def popitemlist(self) -> Tuple[K, List[V]]: ... - def __copy__(self) -> MultiDict[K, V]: ... - def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... - -class _omd_bucket(Generic[K, V]): - prev: Optional[_omd_bucket] - next: Optional[_omd_bucket] - key: K - value: V - def __init__(self, omd: OrderedMultiDict, key: K, value: V) -> None: ... - def unlink(self, omd: OrderedMultiDict) -> None: ... - -class OrderedMultiDict(MultiDict[K, V]): - _first_bucket: Optional[_omd_bucket] - _last_bucket: Optional[_omd_bucket] - def __init__(self, mapping: Optional[Mapping[K, V]] = None) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__(self, key: K) -> V: ... - def __setitem__(self, key: K, value: V) -> None: ... - def __delitem__(self, key: K) -> None: ... - def keys(self) -> Iterator[K]: ... # type: ignore - def __iter__(self) -> Iterator[K]: ... - def values(self) -> Iterator[V]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def listvalues(self) -> Iterator[List[V]]: ... - def add(self, key: K, value: V) -> None: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def setlist(self, key: K, new_list: Iterable[V]) -> None: ... - def setlistdefault( - self, key: K, default_list: Optional[Iterable[V]] = None - ) -> List[V]: ... - def update( # type: ignore - self, mapping: Union[Mapping[K, V], Iterable[Tuple[K, V]]] - ) -> None: ... - def poplist(self, key: K) -> List[V]: ... 
- @overload - def pop(self, key: K) -> V: ... - @overload - def pop(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - def popitem(self) -> Tuple[K, V]: ... - def popitemlist(self) -> Tuple[K, List[V]]: ... - -def _options_header_vkw( - value: str, kw: Mapping[str, Optional[Union[str, int]]] -) -> str: ... -def _unicodify_header_value(value: Union[str, int]) -> str: ... - -HV = Union[str, int] - -class Headers(Dict[str, str]): - _list: List[Tuple[str, str]] - def __init__( - self, - defaults: Optional[ - Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]] - ] = None, - ) -> None: ... - @overload - def __getitem__(self, key: str) -> str: ... - @overload - def __getitem__(self, key: int) -> Tuple[str, str]: ... - @overload - def __getitem__(self, key: slice) -> Headers: ... - @overload - def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... - def __eq__(self, other: object) -> bool: ... - @overload # type: ignore - def get(self, key: str, default: str) -> str: ... - @overload - def get(self, key: str, default: Optional[str] = None) -> Optional[str]: ... - @overload - def get( - self, key: str, default: Optional[T] = None, type: Callable[[str], T] = ... - ) -> Optional[T]: ... - @overload - def getlist(self, key: str) -> List[str]: ... - @overload - def getlist(self, key: str, type: Callable[[str], T]) -> List[T]: ... - def get_all(self, name: str) -> List[str]: ... - def items( # type: ignore - self, lower: bool = False - ) -> Iterator[Tuple[str, str]]: ... - def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore - def values(self) -> Iterator[str]: ... # type: ignore - def extend( - self, - *args: Union[Mapping[str, Union[HV, Iterable[HV]]], Iterable[Tuple[str, HV]]], - **kwargs: Union[HV, Iterable[HV]], - ) -> None: ... - @overload - def __delitem__(self, key: Union[str, int, slice]) -> None: ... - @overload - def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... 
- def remove(self, key: str) -> None: ... - @overload # type: ignore - def pop(self, key: str, default: Optional[str] = None) -> str: ... - @overload - def pop( - self, key: Optional[int] = None, default: Optional[Tuple[str, str]] = None - ) -> Tuple[str, str]: ... - def popitem(self) -> Tuple[str, str]: ... - def __contains__(self, key: str) -> bool: ... # type: ignore - def has_key(self, key: str) -> bool: ... - def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore - def add(self, _key: str, _value: HV, **kw: HV) -> None: ... - def _validate_value(self, value: str) -> None: ... - def add_header(self, _key: str, _value: HV, **_kw: HV) -> None: ... - def clear(self) -> None: ... - def set(self, _key: str, _value: HV, **kw: HV) -> None: ... - def setlist(self, key: str, values: Iterable[HV]) -> None: ... - def setdefault(self, key: str, default: HV) -> str: ... # type: ignore - def setlistdefault(self, key: str, default: Iterable[HV]) -> None: ... - @overload - def __setitem__(self, key: str, value: HV) -> None: ... - @overload - def __setitem__(self, key: int, value: Tuple[str, HV]) -> None: ... - @overload - def __setitem__(self, key: slice, value: Iterable[Tuple[str, HV]]) -> None: ... - @overload - def update( - self, __m: SupportsKeysAndGetItem[str, HV], **kwargs: Union[HV, Iterable[HV]] - ) -> None: ... - @overload - def update( - self, __m: Iterable[Tuple[str, HV]], **kwargs: Union[HV, Iterable[HV]] - ) -> None: ... - @overload - def update(self, **kwargs: Union[HV, Iterable[HV]]) -> None: ... - def to_wsgi_list(self) -> List[Tuple[str, str]]: ... - def copy(self) -> Headers: ... - def __copy__(self) -> Headers: ... - -class ImmutableHeadersMixin(Headers): - def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ... - def __setitem__(self, key: Any, value: Any) -> NoReturn: ... - def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def setlist(self, key: Any, values: Any) -> NoReturn: ... 
- def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... - def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ... - def remove(self, key: Any) -> NoReturn: ... - def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def update(self, *args: Any, **kwargs: Any) -> NoReturn: ... - def insert(self, pos: Any, value: Any) -> NoReturn: ... - def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ... - def popitem(self) -> NoReturn: ... - def setdefault(self, key: Any, default: Any) -> NoReturn: ... # type: ignore - def setlistdefault(self, key: Any, default: Any) -> NoReturn: ... - -class EnvironHeaders(ImmutableHeadersMixin, Headers): - environ: WSGIEnvironment - def __init__(self, environ: WSGIEnvironment) -> None: ... - def __eq__(self, other: object) -> bool: ... - def __getitem__( # type: ignore - self, key: str, _get_mode: Literal[False] = False - ) -> str: ... - def __iter__(self) -> Iterator[Tuple[str, str]]: ... # type: ignore - def copy(self) -> NoReturn: ... - -class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore - dicts: List[MultiDict[K, V]] - def __init__(self, dicts: Optional[Iterable[MultiDict[K, V]]]) -> None: ... - @classmethod - def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... - def __getitem__(self, key: K) -> V: ... - @overload # type: ignore - def get(self, key: K) -> Optional[V]: ... - @overload - def get(self, key: K, default: Union[V, T] = ...) -> Union[V, T]: ... - @overload - def get( - self, key: K, default: Optional[T] = None, type: Callable[[V], T] = ... - ) -> Optional[T]: ... - @overload - def getlist(self, key: K) -> List[V]: ... - @overload - def getlist(self, key: K, type: Callable[[V], T] = ...) -> List[T]: ... - def _keys_impl(self) -> Set[K]: ... - def keys(self) -> Set[K]: ... # type: ignore - def __iter__(self) -> Set[K]: ... # type: ignore - def items(self, multi: bool = False) -> Iterator[Tuple[K, V]]: ... 
# type: ignore - def values(self) -> Iterator[V]: ... # type: ignore - def lists(self) -> Iterator[Tuple[K, List[V]]]: ... - def listvalues(self) -> Iterator[List[V]]: ... - def copy(self) -> MultiDict[K, V]: ... - @overload - def to_dict(self) -> Dict[K, V]: ... - @overload - def to_dict(self, flat: Literal[False]) -> Dict[K, List[V]]: ... - def __contains__(self, key: K) -> bool: ... # type: ignore - def has_key(self, key: K) -> bool: ... - -class FileMultiDict(MultiDict[str, "FileStorage"]): - def add_file( - self, - name: str, - file: Union[FileStorage, str, IO[bytes]], - filename: Optional[str] = None, - content_type: Optional[str] = None, - ) -> None: ... - -class ImmutableDict(ImmutableDictMixin[K, V], Dict[K, V]): - def copy(self) -> Dict[K, V]: ... - def __copy__(self) -> ImmutableDict[K, V]: ... - -class ImmutableMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], MultiDict[K, V] -): - def copy(self) -> MultiDict[K, V]: ... - def __copy__(self) -> ImmutableMultiDict[K, V]: ... - -class ImmutableOrderedMultiDict( # type: ignore - ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] -): - def _iter_hashitems(self) -> Iterator[Tuple[int, Tuple[K, V]]]: ... - def copy(self) -> OrderedMultiDict[K, V]: ... - def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... - -class Accept(ImmutableList[Tuple[str, int]]): - provided: bool - def __init__( - self, values: Optional[Union[Accept, Iterable[Tuple[str, float]]]] = None - ) -> None: ... - def _specificity(self, value: str) -> Tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @overload # type: ignore - def __getitem__(self, key: str) -> int: ... - @overload - def __getitem__(self, key: int) -> Tuple[str, int]: ... - @overload - def __getitem__(self, key: slice) -> Iterable[Tuple[str, int]]: ... - def quality(self, key: str) -> int: ... - def __contains__(self, value: str) -> bool: ... # type: ignore - def index(self, key: str) -> int: ... 
# type: ignore - def find(self, key: str) -> int: ... - def values(self) -> Iterator[str]: ... - def to_header(self) -> str: ... - def _best_single_match(self, match: str) -> Optional[Tuple[str, int]]: ... - def best_match( - self, matches: Iterable[str], default: Optional[str] = None - ) -> Optional[str]: ... - @property - def best(self) -> str: ... - -def _normalize_mime(value: str) -> List[str]: ... - -class MIMEAccept(Accept): - def _specificity(self, value: str) -> Tuple[bool, ...]: ... - def _value_matches(self, value: str, item: str) -> bool: ... - @property - def accept_html(self) -> bool: ... - @property - def accept_xhtml(self) -> bool: ... - @property - def accept_json(self) -> bool: ... - -def _normalize_lang(value: str) -> List[str]: ... - -class LanguageAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... - def best_match( - self, matches: Iterable[str], default: Optional[str] = None - ) -> Optional[str]: ... - -class CharsetAccept(Accept): - def _value_matches(self, value: str, item: str) -> bool: ... - -_CPT = TypeVar("_CPT", str, int, bool) -_OptCPT = Optional[_CPT] - -def cache_control_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... - -class _CacheControl(UpdateDictMixin[str, _OptCPT], Dict[str, _OptCPT]): - provided: bool - def __init__( - self, - values: Union[Mapping[str, _OptCPT], Iterable[Tuple[str, _OptCPT]]] = (), - on_update: Optional[Callable[[_CacheControl], None]] = None, - ) -> None: ... - @property - def no_cache(self) -> Optional[bool]: ... - @no_cache.setter - def no_cache(self, value: Optional[bool]) -> None: ... - @no_cache.deleter - def no_cache(self) -> None: ... - @property - def no_store(self) -> Optional[bool]: ... - @no_store.setter - def no_store(self, value: Optional[bool]) -> None: ... - @no_store.deleter - def no_store(self) -> None: ... - @property - def max_age(self) -> Optional[int]: ... - @max_age.setter - def max_age(self, value: Optional[int]) -> None: ... 
- @max_age.deleter - def max_age(self) -> None: ... - @property - def no_transform(self) -> Optional[bool]: ... - @no_transform.setter - def no_transform(self, value: Optional[bool]) -> None: ... - @no_transform.deleter - def no_transform(self) -> None: ... - def _get_cache_value(self, key: str, empty: Optional[T], type: Type[T]) -> T: ... - def _set_cache_value(self, key: str, value: Optional[T], type: Type[T]) -> None: ... - def _del_cache_value(self, key: str) -> None: ... - def to_header(self) -> str: ... - @staticmethod - def cache_property(key: str, empty: _OptCPT, type: Type[_CPT]) -> property: ... - -class RequestCacheControl(ImmutableDictMixin[str, _OptCPT], _CacheControl): - @property - def max_stale(self) -> Optional[int]: ... - @max_stale.setter - def max_stale(self, value: Optional[int]) -> None: ... - @max_stale.deleter - def max_stale(self) -> None: ... - @property - def min_fresh(self) -> Optional[int]: ... - @min_fresh.setter - def min_fresh(self, value: Optional[int]) -> None: ... - @min_fresh.deleter - def min_fresh(self) -> None: ... - @property - def only_if_cached(self) -> Optional[bool]: ... - @only_if_cached.setter - def only_if_cached(self, value: Optional[bool]) -> None: ... - @only_if_cached.deleter - def only_if_cached(self) -> None: ... - -class ResponseCacheControl(_CacheControl): - @property - def public(self) -> Optional[bool]: ... - @public.setter - def public(self, value: Optional[bool]) -> None: ... - @public.deleter - def public(self) -> None: ... - @property - def private(self) -> Optional[bool]: ... - @private.setter - def private(self, value: Optional[bool]) -> None: ... - @private.deleter - def private(self) -> None: ... - @property - def must_revalidate(self) -> Optional[bool]: ... - @must_revalidate.setter - def must_revalidate(self, value: Optional[bool]) -> None: ... - @must_revalidate.deleter - def must_revalidate(self) -> None: ... - @property - def proxy_revalidate(self) -> Optional[bool]: ... 
- @proxy_revalidate.setter - def proxy_revalidate(self, value: Optional[bool]) -> None: ... - @proxy_revalidate.deleter - def proxy_revalidate(self) -> None: ... - @property - def s_maxage(self) -> Optional[int]: ... - @s_maxage.setter - def s_maxage(self, value: Optional[int]) -> None: ... - @s_maxage.deleter - def s_maxage(self) -> None: ... - @property - def immutable(self) -> Optional[bool]: ... - @immutable.setter - def immutable(self, value: Optional[bool]) -> None: ... - @immutable.deleter - def immutable(self) -> None: ... - -def csp_property(key: str) -> property: ... - -class ContentSecurityPolicy(UpdateDictMixin[str, str], Dict[str, str]): - @property - def base_uri(self) -> Optional[str]: ... - @base_uri.setter - def base_uri(self, value: Optional[str]) -> None: ... - @base_uri.deleter - def base_uri(self) -> None: ... - @property - def child_src(self) -> Optional[str]: ... - @child_src.setter - def child_src(self, value: Optional[str]) -> None: ... - @child_src.deleter - def child_src(self) -> None: ... - @property - def connect_src(self) -> Optional[str]: ... - @connect_src.setter - def connect_src(self, value: Optional[str]) -> None: ... - @connect_src.deleter - def connect_src(self) -> None: ... - @property - def default_src(self) -> Optional[str]: ... - @default_src.setter - def default_src(self, value: Optional[str]) -> None: ... - @default_src.deleter - def default_src(self) -> None: ... - @property - def font_src(self) -> Optional[str]: ... - @font_src.setter - def font_src(self, value: Optional[str]) -> None: ... - @font_src.deleter - def font_src(self) -> None: ... - @property - def form_action(self) -> Optional[str]: ... - @form_action.setter - def form_action(self, value: Optional[str]) -> None: ... - @form_action.deleter - def form_action(self) -> None: ... - @property - def frame_ancestors(self) -> Optional[str]: ... - @frame_ancestors.setter - def frame_ancestors(self, value: Optional[str]) -> None: ... 
- @frame_ancestors.deleter - def frame_ancestors(self) -> None: ... - @property - def frame_src(self) -> Optional[str]: ... - @frame_src.setter - def frame_src(self, value: Optional[str]) -> None: ... - @frame_src.deleter - def frame_src(self) -> None: ... - @property - def img_src(self) -> Optional[str]: ... - @img_src.setter - def img_src(self, value: Optional[str]) -> None: ... - @img_src.deleter - def img_src(self) -> None: ... - @property - def manifest_src(self) -> Optional[str]: ... - @manifest_src.setter - def manifest_src(self, value: Optional[str]) -> None: ... - @manifest_src.deleter - def manifest_src(self) -> None: ... - @property - def media_src(self) -> Optional[str]: ... - @media_src.setter - def media_src(self, value: Optional[str]) -> None: ... - @media_src.deleter - def media_src(self) -> None: ... - @property - def navigate_to(self) -> Optional[str]: ... - @navigate_to.setter - def navigate_to(self, value: Optional[str]) -> None: ... - @navigate_to.deleter - def navigate_to(self) -> None: ... - @property - def object_src(self) -> Optional[str]: ... - @object_src.setter - def object_src(self, value: Optional[str]) -> None: ... - @object_src.deleter - def object_src(self) -> None: ... - @property - def prefetch_src(self) -> Optional[str]: ... - @prefetch_src.setter - def prefetch_src(self, value: Optional[str]) -> None: ... - @prefetch_src.deleter - def prefetch_src(self) -> None: ... - @property - def plugin_types(self) -> Optional[str]: ... - @plugin_types.setter - def plugin_types(self, value: Optional[str]) -> None: ... - @plugin_types.deleter - def plugin_types(self) -> None: ... - @property - def report_to(self) -> Optional[str]: ... - @report_to.setter - def report_to(self, value: Optional[str]) -> None: ... - @report_to.deleter - def report_to(self) -> None: ... - @property - def report_uri(self) -> Optional[str]: ... - @report_uri.setter - def report_uri(self, value: Optional[str]) -> None: ... 
- @report_uri.deleter - def report_uri(self) -> None: ... - @property - def sandbox(self) -> Optional[str]: ... - @sandbox.setter - def sandbox(self, value: Optional[str]) -> None: ... - @sandbox.deleter - def sandbox(self) -> None: ... - @property - def script_src(self) -> Optional[str]: ... - @script_src.setter - def script_src(self, value: Optional[str]) -> None: ... - @script_src.deleter - def script_src(self) -> None: ... - @property - def script_src_attr(self) -> Optional[str]: ... - @script_src_attr.setter - def script_src_attr(self, value: Optional[str]) -> None: ... - @script_src_attr.deleter - def script_src_attr(self) -> None: ... - @property - def script_src_elem(self) -> Optional[str]: ... - @script_src_elem.setter - def script_src_elem(self, value: Optional[str]) -> None: ... - @script_src_elem.deleter - def script_src_elem(self) -> None: ... - @property - def style_src(self) -> Optional[str]: ... - @style_src.setter - def style_src(self, value: Optional[str]) -> None: ... - @style_src.deleter - def style_src(self) -> None: ... - @property - def style_src_attr(self) -> Optional[str]: ... - @style_src_attr.setter - def style_src_attr(self, value: Optional[str]) -> None: ... - @style_src_attr.deleter - def style_src_attr(self) -> None: ... - @property - def style_src_elem(self) -> Optional[str]: ... - @style_src_elem.setter - def style_src_elem(self, value: Optional[str]) -> None: ... - @style_src_elem.deleter - def style_src_elem(self) -> None: ... - @property - def worker_src(self) -> Optional[str]: ... - @worker_src.setter - def worker_src(self, value: Optional[str]) -> None: ... - @worker_src.deleter - def worker_src(self) -> None: ... - provided: bool - def __init__( - self, - values: Union[Mapping[str, str], Iterable[Tuple[str, str]]] = (), - on_update: Optional[Callable[[ContentSecurityPolicy], None]] = None, - ) -> None: ... - def _get_value(self, key: str) -> Optional[str]: ... - def _set_value(self, key: str, value: str) -> None: ... 
- def _del_value(self, key: str) -> None: ... - def to_header(self) -> str: ... - -class CallbackDict(UpdateDictMixin[K, V], Dict[K, V]): - def __init__( - self, - initial: Optional[Union[Mapping[K, V], Iterable[Tuple[K, V]]]] = None, - on_update: Optional[Callable[[_CD], None]] = None, - ) -> None: ... - -class HeaderSet(Set[str]): - _headers: List[str] - _set: Set[str] - on_update: Optional[Callable[[HeaderSet], None]] - def __init__( - self, - headers: Optional[Iterable[str]] = None, - on_update: Optional[Callable[[HeaderSet], None]] = None, - ) -> None: ... - def add(self, header: str) -> None: ... - def remove(self, header: str) -> None: ... - def update(self, iterable: Iterable[str]) -> None: ... # type: ignore - def discard(self, header: str) -> None: ... - def find(self, header: str) -> int: ... - def index(self, header: str) -> int: ... - def clear(self) -> None: ... - def as_set(self, preserve_casing: bool = False) -> Set[str]: ... - def to_header(self) -> str: ... - def __getitem__(self, idx: int) -> str: ... - def __delitem__(self, idx: int) -> None: ... - def __setitem__(self, idx: int, value: str) -> None: ... - def __contains__(self, header: str) -> bool: ... # type: ignore - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... - -class ETags(Collection[str]): - _strong: FrozenSet[str] - _weak: FrozenSet[str] - star_tag: bool - def __init__( - self, - strong_etags: Optional[Iterable[str]] = None, - weak_etags: Optional[Iterable[str]] = None, - star_tag: bool = False, - ) -> None: ... - def as_set(self, include_weak: bool = False) -> Set[str]: ... - def is_weak(self, etag: str) -> bool: ... - def is_strong(self, etag: str) -> bool: ... - def contains_weak(self, etag: str) -> bool: ... - def contains(self, etag: str) -> bool: ... - def contains_raw(self, etag: str) -> bool: ... - def to_header(self) -> str: ... 
- def __call__( - self, - etag: Optional[str] = None, - data: Optional[bytes] = None, - include_weak: bool = False, - ) -> bool: ... - def __len__(self) -> int: ... - def __iter__(self) -> Iterator[str]: ... - def __contains__(self, item: str) -> bool: ... # type: ignore - -class IfRange: - etag: Optional[str] - date: Optional[datetime] - def __init__( - self, etag: Optional[str] = None, date: Optional[datetime] = None - ) -> None: ... - def to_header(self) -> str: ... - -class Range: - units: str - ranges: List[Tuple[int, Optional[int]]] - def __init__(self, units: str, ranges: List[Tuple[int, Optional[int]]]) -> None: ... - def range_for_length(self, length: Optional[int]) -> Optional[Tuple[int, int]]: ... - def make_content_range(self, length: Optional[int]) -> Optional[ContentRange]: ... - def to_header(self) -> str: ... - def to_content_range_header(self, length: Optional[int]) -> Optional[str]: ... - -def _callback_property(name: str) -> property: ... - -class ContentRange: - on_update: Optional[Callable[[ContentRange], None]] - def __init__( - self, - units: Optional[str], - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - on_update: Optional[Callable[[ContentRange], None]] = None, - ) -> None: ... - @property - def units(self) -> Optional[str]: ... - @units.setter - def units(self, value: Optional[str]) -> None: ... - @property - def start(self) -> Optional[int]: ... - @start.setter - def start(self, value: Optional[int]) -> None: ... - @property - def stop(self) -> Optional[int]: ... - @stop.setter - def stop(self, value: Optional[int]) -> None: ... - @property - def length(self) -> Optional[int]: ... - @length.setter - def length(self, value: Optional[int]) -> None: ... - def set( - self, - start: Optional[int], - stop: Optional[int], - length: Optional[int] = None, - units: Optional[str] = "bytes", - ) -> None: ... - def unset(self) -> None: ... - def to_header(self) -> str: ... 
- -class Authorization(ImmutableDictMixin[str, str], Dict[str, str]): - type: str - def __init__( - self, - auth_type: str, - data: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, - ) -> None: ... - @property - def username(self) -> Optional[str]: ... - @property - def password(self) -> Optional[str]: ... - @property - def realm(self) -> Optional[str]: ... - @property - def nonce(self) -> Optional[str]: ... - @property - def uri(self) -> Optional[str]: ... - @property - def nc(self) -> Optional[str]: ... - @property - def cnonce(self) -> Optional[str]: ... - @property - def response(self) -> Optional[str]: ... - @property - def opaque(self) -> Optional[str]: ... - @property - def qop(self) -> Optional[str]: ... - def to_header(self) -> str: ... - -def auth_property(name: str, doc: Optional[str] = None) -> property: ... -def _set_property(name: str, doc: Optional[str] = None) -> property: ... - -class WWWAuthenticate(UpdateDictMixin[str, str], Dict[str, str]): - _require_quoting: FrozenSet[str] - def __init__( - self, - auth_type: Optional[str] = None, - values: Optional[Union[Mapping[str, str], Iterable[Tuple[str, str]]]] = None, - on_update: Optional[Callable[[WWWAuthenticate], None]] = None, - ) -> None: ... - def set_basic(self, realm: str = ...) -> None: ... - def set_digest( - self, - realm: str, - nonce: str, - qop: Iterable[str] = ("auth",), - opaque: Optional[str] = None, - algorithm: Optional[str] = None, - stale: bool = False, - ) -> None: ... - def to_header(self) -> str: ... - @property - def type(self) -> Optional[str]: ... - @type.setter - def type(self, value: Optional[str]) -> None: ... - @property - def realm(self) -> Optional[str]: ... - @realm.setter - def realm(self, value: Optional[str]) -> None: ... - @property - def domain(self) -> HeaderSet: ... - @property - def nonce(self) -> Optional[str]: ... - @nonce.setter - def nonce(self, value: Optional[str]) -> None: ... - @property - def opaque(self) -> Optional[str]: ... 
- @opaque.setter - def opaque(self, value: Optional[str]) -> None: ... - @property - def algorithm(self) -> Optional[str]: ... - @algorithm.setter - def algorithm(self, value: Optional[str]) -> None: ... - @property - def qop(self) -> HeaderSet: ... - @property - def stale(self) -> Optional[bool]: ... - @stale.setter - def stale(self, value: Optional[bool]) -> None: ... - @staticmethod - def auth_property(name: str, doc: Optional[str] = None) -> property: ... - -class FileStorage: - name: Optional[str] - stream: IO[bytes] - filename: Optional[str] - headers: Headers - _parsed_content_type: Tuple[str, Dict[str, str]] - def __init__( - self, - stream: Optional[IO[bytes]] = None, - filename: Union[str, PathLike, None] = None, - name: Optional[str] = None, - content_type: Optional[str] = None, - content_length: Optional[int] = None, - headers: Optional[Headers] = None, - ) -> None: ... - def _parse_content_type(self) -> None: ... - @property - def content_type(self) -> str: ... - @property - def content_length(self) -> int: ... - @property - def mimetype(self) -> str: ... - @property - def mimetype_params(self) -> Dict[str, str]: ... - def save( - self, dst: Union[str, PathLike, IO[bytes]], buffer_size: int = ... - ) -> None: ... - def close(self) -> None: ... - def __bool__(self) -> bool: ... - def __getattr__(self, name: str) -> Any: ... - def __iter__(self) -> Iterator[bytes]: ... - def __repr__(self) -> str: ... 
diff --git a/src/werkzeug/datastructures/__init__.py b/src/werkzeug/datastructures/__init__.py new file mode 100644 index 0000000..846ffce --- /dev/null +++ b/src/werkzeug/datastructures/__init__.py @@ -0,0 +1,34 @@ +from .accept import Accept as Accept +from .accept import CharsetAccept as CharsetAccept +from .accept import LanguageAccept as LanguageAccept +from .accept import MIMEAccept as MIMEAccept +from .auth import Authorization as Authorization +from .auth import WWWAuthenticate as WWWAuthenticate +from .cache_control import RequestCacheControl as RequestCacheControl +from .cache_control import ResponseCacheControl as ResponseCacheControl +from .csp import ContentSecurityPolicy as ContentSecurityPolicy +from .etag import ETags as ETags +from .file_storage import FileMultiDict as FileMultiDict +from .file_storage import FileStorage as FileStorage +from .headers import EnvironHeaders as EnvironHeaders +from .headers import Headers as Headers +from .mixins import ImmutableDictMixin as ImmutableDictMixin +from .mixins import ImmutableHeadersMixin as ImmutableHeadersMixin +from .mixins import ImmutableListMixin as ImmutableListMixin +from .mixins import ImmutableMultiDictMixin as ImmutableMultiDictMixin +from .mixins import UpdateDictMixin as UpdateDictMixin +from .range import ContentRange as ContentRange +from .range import IfRange as IfRange +from .range import Range as Range +from .structures import CallbackDict as CallbackDict +from .structures import CombinedMultiDict as CombinedMultiDict +from .structures import HeaderSet as HeaderSet +from .structures import ImmutableDict as ImmutableDict +from .structures import ImmutableList as ImmutableList +from .structures import ImmutableMultiDict as ImmutableMultiDict +from .structures import ImmutableOrderedMultiDict as ImmutableOrderedMultiDict +from .structures import ImmutableTypeConversionDict as ImmutableTypeConversionDict +from .structures import iter_multi_items as iter_multi_items +from .structures import 
MultiDict as MultiDict +from .structures import OrderedMultiDict as OrderedMultiDict +from .structures import TypeConversionDict as TypeConversionDict diff --git a/src/werkzeug/datastructures/accept.py b/src/werkzeug/datastructures/accept.py new file mode 100644 index 0000000..d80f0bb --- /dev/null +++ b/src/werkzeug/datastructures/accept.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import codecs +import re + +from .structures import ImmutableList + + +class Accept(ImmutableList): + """An :class:`Accept` object is just a list subclass for lists of + ``(value, quality)`` tuples. It is automatically sorted by specificity + and quality. + + All :class:`Accept` objects work similar to a list but provide extra + functionality for working with the data. Containment checks are + normalized to the rules of that header: + + >>> a = CharsetAccept([('ISO-8859-1', 1), ('utf-8', 0.7)]) + >>> a.best + 'ISO-8859-1' + >>> 'iso-8859-1' in a + True + >>> 'UTF8' in a + True + >>> 'utf7' in a + False + + To get the quality for an item you can use normal item lookup: + + >>> print a['utf-8'] + 0.7 + >>> a['utf7'] + 0 + + .. versionchanged:: 0.5 + :class:`Accept` objects are forced immutable now. + + .. versionchanged:: 1.0.0 + :class:`Accept` internal values are no longer ordered + alphabetically for equal quality tags. Instead the initial + order is preserved. 
+ + """ + + def __init__(self, values=()): + if values is None: + list.__init__(self) + self.provided = False + elif isinstance(values, Accept): + self.provided = values.provided + list.__init__(self, values) + else: + self.provided = True + values = sorted( + values, key=lambda x: (self._specificity(x[0]), x[1]), reverse=True + ) + list.__init__(self, values) + + def _specificity(self, value): + """Returns a tuple describing the value's specificity.""" + return (value != "*",) + + def _value_matches(self, value, item): + """Check if a value matches a given accept item.""" + return item == "*" or item.lower() == value.lower() + + def __getitem__(self, key): + """Besides index lookup (getting item n) you can also pass it a string + to get the quality for the item. If the item is not in the list, the + returned quality is ``0``. + """ + if isinstance(key, str): + return self.quality(key) + return list.__getitem__(self, key) + + def quality(self, key): + """Returns the quality of the key. + + .. versionadded:: 0.6 + In previous versions you had to use the item-lookup syntax + (eg: ``obj[key]`` instead of ``obj.quality(key)``) + """ + for item, quality in self: + if self._value_matches(key, item): + return quality + return 0 + + def __contains__(self, value): + for item, _quality in self: + if self._value_matches(value, item): + return True + return False + + def __repr__(self): + pairs_str = ", ".join(f"({x!r}, {y})" for x, y in self) + return f"{type(self).__name__}([{pairs_str}])" + + def index(self, key): + """Get the position of an entry or raise :exc:`ValueError`. + + :param key: The key to be looked up. + + .. versionchanged:: 0.5 + This used to raise :exc:`IndexError`, which was inconsistent + with the list API. 
+ """ + if isinstance(key, str): + for idx, (item, _quality) in enumerate(self): + if self._value_matches(key, item): + return idx + raise ValueError(key) + return list.index(self, key) + + def find(self, key): + """Get the position of an entry or return -1. + + :param key: The key to be looked up. + """ + try: + return self.index(key) + except ValueError: + return -1 + + def values(self): + """Iterate over all values.""" + for item in self: + yield item[0] + + def to_header(self): + """Convert the header set into an HTTP header string.""" + result = [] + for value, quality in self: + if quality != 1: + value = f"{value};q={quality}" + result.append(value) + return ",".join(result) + + def __str__(self): + return self.to_header() + + def _best_single_match(self, match): + for client_item, quality in self: + if self._value_matches(match, client_item): + # self is sorted by specificity descending, we can exit + return client_item, quality + return None + + def best_match(self, matches, default=None): + """Returns the best match from a list of possible matches based + on the specificity and quality of the client. If two items have the + same quality and specificity, the one is returned that comes first. 
+ + :param matches: a list of matches to check for + :param default: the value that is returned if none match + """ + result = default + best_quality = -1 + best_specificity = (-1,) + for server_item in matches: + match = self._best_single_match(server_item) + if not match: + continue + client_item, quality = match + specificity = self._specificity(client_item) + if quality <= 0 or quality < best_quality: + continue + # better quality or same quality but more specific => better match + if quality > best_quality or specificity > best_specificity: + result = server_item + best_quality = quality + best_specificity = specificity + return result + + @property + def best(self): + """The best match as value.""" + if self: + return self[0][0] + + +_mime_split_re = re.compile(r"/|(?:\s*;\s*)") + + +def _normalize_mime(value): + return _mime_split_re.split(value.lower()) + + +class MIMEAccept(Accept): + """Like :class:`Accept` but with special methods and behavior for + mimetypes. + """ + + def _specificity(self, value): + return tuple(x != "*" for x in _mime_split_re.split(value)) + + def _value_matches(self, value, item): + # item comes from the client, can't match if it's invalid. + if "/" not in item: + return False + + # value comes from the application, tell the developer when it + # doesn't look valid. + if "/" not in value: + raise ValueError(f"invalid mimetype {value!r}") + + # Split the match value into type, subtype, and a sorted list of parameters. + normalized_value = _normalize_mime(value) + value_type, value_subtype = normalized_value[:2] + value_params = sorted(normalized_value[2:]) + + # "*/*" is the only valid value that can start with "*". + if value_type == "*" and value_subtype != "*": + raise ValueError(f"invalid mimetype {value!r}") + + # Split the accept item into type, subtype, and parameters. 
+ normalized_item = _normalize_mime(item) + item_type, item_subtype = normalized_item[:2] + item_params = sorted(normalized_item[2:]) + + # "*/not-*" from the client is invalid, can't match. + if item_type == "*" and item_subtype != "*": + return False + + return ( + (item_type == "*" and item_subtype == "*") + or (value_type == "*" and value_subtype == "*") + ) or ( + item_type == value_type + and ( + item_subtype == "*" + or value_subtype == "*" + or (item_subtype == value_subtype and item_params == value_params) + ) + ) + + @property + def accept_html(self): + """True if this object accepts HTML.""" + return ( + "text/html" in self or "application/xhtml+xml" in self or self.accept_xhtml + ) + + @property + def accept_xhtml(self): + """True if this object accepts XHTML.""" + return "application/xhtml+xml" in self or "application/xml" in self + + @property + def accept_json(self): + """True if this object accepts JSON.""" + return "application/json" in self + + +_locale_delim_re = re.compile(r"[_-]") + + +def _normalize_lang(value): + """Process a language tag for matching.""" + return _locale_delim_re.split(value.lower()) + + +class LanguageAccept(Accept): + """Like :class:`Accept` but with normalization for language tags.""" + + def _value_matches(self, value, item): + return item == "*" or _normalize_lang(value) == _normalize_lang(item) + + def best_match(self, matches, default=None): + """Given a list of supported values, finds the best match from + the list of accepted values. + + Language tags are normalized for the purpose of matching, but + are returned unchanged. + + If no exact match is found, this will fall back to matching + the first subtag (primary language only), first with the + accepted values then with the match values. This partial is not + applied to any other language subtags. + + The default is returned if no exact or fallback match is found. + + :param matches: A list of supported languages to find a match. 
+ :param default: The value that is returned if none match. + """ + # Look for an exact match first. If a client accepts "en-US", + # "en-US" is a valid match at this point. + result = super().best_match(matches) + + if result is not None: + return result + + # Fall back to accepting primary tags. If a client accepts + # "en-US", "en" is a valid match at this point. Need to use + # re.split to account for 2 or 3 letter codes. + fallback = Accept( + [(_locale_delim_re.split(item[0], 1)[0], item[1]) for item in self] + ) + result = fallback.best_match(matches) + + if result is not None: + return result + + # Fall back to matching primary tags. If the client accepts + # "en", "en-US" is a valid match at this point. + fallback_matches = [_locale_delim_re.split(item, 1)[0] for item in matches] + result = super().best_match(fallback_matches) + + # Return a value from the original match list. Find the first + # original value that starts with the matched primary tag. + if result is not None: + return next(item for item in matches if item.startswith(result)) + + return default + + +class CharsetAccept(Accept): + """Like :class:`Accept` but with normalization for charsets.""" + + def _value_matches(self, value, item): + def _normalize(name): + try: + return codecs.lookup(name).name + except LookupError: + return name.lower() + + return item == "*" or _normalize(value) == _normalize(item) diff --git a/src/werkzeug/datastructures/accept.pyi b/src/werkzeug/datastructures/accept.pyi new file mode 100644 index 0000000..4b74dd9 --- /dev/null +++ b/src/werkzeug/datastructures/accept.pyi @@ -0,0 +1,54 @@ +from collections.abc import Iterable +from collections.abc import Iterator +from typing import overload + +from .structures import ImmutableList + +class Accept(ImmutableList[tuple[str, int]]): + provided: bool + def __init__( + self, values: Accept | Iterable[tuple[str, float]] | None = None + ) -> None: ... + def _specificity(self, value: str) -> tuple[bool, ...]: ... 
+ def _value_matches(self, value: str, item: str) -> bool: ... + @overload # type: ignore + def __getitem__(self, key: str) -> int: ... + @overload + def __getitem__(self, key: int) -> tuple[str, int]: ... + @overload + def __getitem__(self, key: slice) -> Iterable[tuple[str, int]]: ... + def quality(self, key: str) -> int: ... + def __contains__(self, value: str) -> bool: ... # type: ignore + def index(self, key: str) -> int: ... # type: ignore + def find(self, key: str) -> int: ... + def values(self) -> Iterator[str]: ... + def to_header(self) -> str: ... + def _best_single_match(self, match: str) -> tuple[str, int] | None: ... + @overload + def best_match(self, matches: Iterable[str], default: str) -> str: ... + @overload + def best_match( + self, matches: Iterable[str], default: str | None = None + ) -> str | None: ... + @property + def best(self) -> str: ... + +def _normalize_mime(value: str) -> list[str]: ... + +class MIMEAccept(Accept): + def _specificity(self, value: str) -> tuple[bool, ...]: ... + def _value_matches(self, value: str, item: str) -> bool: ... + @property + def accept_html(self) -> bool: ... + @property + def accept_xhtml(self) -> bool: ... + @property + def accept_json(self) -> bool: ... + +def _normalize_lang(value: str) -> list[str]: ... + +class LanguageAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... + +class CharsetAccept(Accept): + def _value_matches(self, value: str, item: str) -> bool: ... 
diff --git a/src/werkzeug/datastructures/auth.py b/src/werkzeug/datastructures/auth.py new file mode 100644 index 0000000..494576d --- /dev/null +++ b/src/werkzeug/datastructures/auth.py @@ -0,0 +1,318 @@ +from __future__ import annotations + +import base64 +import binascii +import typing as t + +from ..http import dump_header +from ..http import parse_dict_header +from ..http import quote_header_value +from .structures import CallbackDict + +if t.TYPE_CHECKING: + import typing_extensions as te + + +class Authorization: + """Represents the parts of an ``Authorization`` request header. + + :attr:`.Request.authorization` returns an instance if the header is set. + + An instance can be used with the test :class:`.Client` request methods' ``auth`` + parameter to send the header in test requests. + + Depending on the auth scheme, either :attr:`parameters` or :attr:`token` will be + set. The ``Basic`` scheme's token is decoded into the ``username`` and ``password`` + parameters. + + For convenience, ``auth["key"]`` and ``auth.key`` both access the key in the + :attr:`parameters` dict, along with ``auth.get("key")`` and ``"key" in auth``. + + .. versionchanged:: 2.3 + The ``token`` parameter and attribute was added to support auth schemes that use + a token instead of parameters, such as ``Bearer``. + + .. versionchanged:: 2.3 + The object is no longer a ``dict``. + + .. versionchanged:: 0.5 + The object is an immutable dict. + """ + + def __init__( + self, + auth_type: str, + data: dict[str, str | None] | None = None, + token: str | None = None, + ) -> None: + self.type = auth_type + """The authorization scheme, like ``basic``, ``digest``, or ``bearer``.""" + + if data is None: + data = {} + + self.parameters = data + """A dict of parameters parsed from the header. Either this or :attr:`token` + will have a value for a given scheme. + """ + + self.token = token + """A token parsed from the header. Either this or :attr:`parameters` will have a + value for a given scheme. 
+ + .. versionadded:: 2.3 + """ + + def __getattr__(self, name: str) -> str | None: + return self.parameters.get(name) + + def __getitem__(self, name: str) -> str | None: + return self.parameters.get(name) + + def get(self, key: str, default: str | None = None) -> str | None: + return self.parameters.get(key, default) + + def __contains__(self, key: str) -> bool: + return key in self.parameters + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Authorization): + return NotImplemented + + return ( + other.type == self.type + and other.token == self.token + and other.parameters == self.parameters + ) + + @classmethod + def from_header(cls, value: str | None) -> te.Self | None: + """Parse an ``Authorization`` header value and return an instance, or ``None`` + if the value is empty. + + :param value: The header value to parse. + + .. versionadded:: 2.3 + """ + if not value: + return None + + scheme, _, rest = value.partition(" ") + scheme = scheme.lower() + rest = rest.strip() + + if scheme == "basic": + try: + username, _, password = base64.b64decode(rest).decode().partition(":") + except (binascii.Error, UnicodeError): + return None + + return cls(scheme, {"username": username, "password": password}) + + if "=" in rest.rstrip("="): + # = that is not trailing, this is parameters. + return cls(scheme, parse_dict_header(rest), None) + + # No = or only trailing =, this is a token. + return cls(scheme, None, rest) + + def to_header(self) -> str: + """Produce an ``Authorization`` header value representing this data. + + .. 
versionadded:: 2.0 + """ + if self.type == "basic": + value = base64.b64encode( + f"{self.username}:{self.password}".encode() + ).decode("utf8") + return f"Basic {value}" + + if self.token is not None: + return f"{self.type.title()} {self.token}" + + return f"{self.type.title()} {dump_header(self.parameters)}" + + def __str__(self) -> str: + return self.to_header() + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.to_header()}>" + + +class WWWAuthenticate: + """Represents the parts of a ``WWW-Authenticate`` response header. + + Set :attr:`.Response.www_authenticate` to an instance of list of instances to set + values for this header in the response. Modifying this instance will modify the + header value. + + Depending on the auth scheme, either :attr:`parameters` or :attr:`token` should be + set. The ``Basic`` scheme will encode ``username`` and ``password`` parameters to a + token. + + For convenience, ``auth["key"]`` and ``auth.key`` both act on the :attr:`parameters` + dict, and can be used to get, set, or delete parameters. ``auth.get("key")`` and + ``"key" in auth`` are also provided. + + .. versionchanged:: 2.3 + The ``token`` parameter and attribute was added to support auth schemes that use + a token instead of parameters, such as ``Bearer``. + + .. versionchanged:: 2.3 + The object is no longer a ``dict``. + + .. versionchanged:: 2.3 + The ``on_update`` parameter was removed. 
+ """ + + def __init__( + self, + auth_type: str, + values: dict[str, str | None] | None = None, + token: str | None = None, + ): + self._type = auth_type.lower() + self._parameters: dict[str, str | None] = CallbackDict( # type: ignore[misc] + values, lambda _: self._trigger_on_update() + ) + self._token = token + self._on_update: t.Callable[[WWWAuthenticate], None] | None = None + + def _trigger_on_update(self) -> None: + if self._on_update is not None: + self._on_update(self) + + @property + def type(self) -> str: + """The authorization scheme, like ``basic``, ``digest``, or ``bearer``.""" + return self._type + + @type.setter + def type(self, value: str) -> None: + self._type = value + self._trigger_on_update() + + @property + def parameters(self) -> dict[str, str | None]: + """A dict of parameters for the header. Only one of this or :attr:`token` should + have a value for a given scheme. + """ + return self._parameters + + @parameters.setter + def parameters(self, value: dict[str, str]) -> None: + self._parameters = CallbackDict( # type: ignore[misc] + value, lambda _: self._trigger_on_update() + ) + self._trigger_on_update() + + @property + def token(self) -> str | None: + """A dict of parameters for the header. Only one of this or :attr:`token` should + have a value for a given scheme. + """ + return self._token + + @token.setter + def token(self, value: str | None) -> None: + """A token for the header. Only one of this or :attr:`parameters` should have a + value for a given scheme. + + .. 
versionadded:: 2.3 + """ + self._token = value + self._trigger_on_update() + + def __getitem__(self, key: str) -> str | None: + return self.parameters.get(key) + + def __setitem__(self, key: str, value: str | None) -> None: + if value is None: + if key in self.parameters: + del self.parameters[key] + else: + self.parameters[key] = value + + self._trigger_on_update() + + def __delitem__(self, key: str) -> None: + if key in self.parameters: + del self.parameters[key] + self._trigger_on_update() + + def __getattr__(self, name: str) -> str | None: + return self[name] + + def __setattr__(self, name: str, value: str | None) -> None: + if name in {"_type", "_parameters", "_token", "_on_update"}: + super().__setattr__(name, value) + else: + self[name] = value + + def __delattr__(self, name: str) -> None: + del self[name] + + def __contains__(self, key: str) -> bool: + return key in self.parameters + + def __eq__(self, other: object) -> bool: + if not isinstance(other, WWWAuthenticate): + return NotImplemented + + return ( + other.type == self.type + and other.token == self.token + and other.parameters == self.parameters + ) + + def get(self, key: str, default: str | None = None) -> str | None: + return self.parameters.get(key, default) + + @classmethod + def from_header(cls, value: str | None) -> te.Self | None: + """Parse a ``WWW-Authenticate`` header value and return an instance, or ``None`` + if the value is empty. + + :param value: The header value to parse. + + .. versionadded:: 2.3 + """ + if not value: + return None + + scheme, _, rest = value.partition(" ") + scheme = scheme.lower() + rest = rest.strip() + + if "=" in rest.rstrip("="): + # = that is not trailing, this is parameters. + return cls(scheme, parse_dict_header(rest), None) + + # No = or only trailing =, this is a token. 
+ return cls(scheme, None, rest) + + def to_header(self) -> str: + """Produce a ``WWW-Authenticate`` header value representing this data.""" + if self.token is not None: + return f"{self.type.title()} {self.token}" + + if self.type == "digest": + items = [] + + for key, value in self.parameters.items(): + if key in {"realm", "domain", "nonce", "opaque", "qop"}: + value = quote_header_value(value, allow_token=False) + else: + value = quote_header_value(value) + + items.append(f"{key}={value}") + + return f"Digest {', '.join(items)}" + + return f"{self.type.title()} {dump_header(self.parameters)}" + + def __str__(self) -> str: + return self.to_header() + + def __repr__(self) -> str: + return f"<{type(self).__name__} {self.to_header()}>" diff --git a/src/werkzeug/datastructures/cache_control.py b/src/werkzeug/datastructures/cache_control.py new file mode 100644 index 0000000..bff4c18 --- /dev/null +++ b/src/werkzeug/datastructures/cache_control.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from .mixins import ImmutableDictMixin +from .mixins import UpdateDictMixin + + +def cache_control_property(key, empty, type): + """Return a new property object for a cache header. Useful if you + want to add support for a cache extension in a subclass. + + .. versionchanged:: 2.0 + Renamed from ``cache_property``. + """ + return property( + lambda x: x._get_cache_value(key, empty, type), + lambda x, v: x._set_cache_value(key, v, type), + lambda x: x._del_cache_value(key), + f"accessor for {key!r}", + ) + + +class _CacheControl(UpdateDictMixin, dict): + """Subclass of a dict that stores values for a Cache-Control header. It + has accessors for all the cache-control directives specified in RFC 2616. + The class does not differentiate between request and response directives. + + Because the cache-control directives in the HTTP header use dashes the + python descriptors use underscores for that. 
+ + To get a header of the :class:`CacheControl` object again you can convert + the object into a string or call the :meth:`to_header` method. If you plan + to subclass it and add your own items have a look at the sourcecode for + that class. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionchanged:: 0.4 + + Setting `no_cache` or `private` to boolean `True` will set the implicit + none-value which is ``*``: + + >>> cc = ResponseCacheControl() + >>> cc.no_cache = True + >>> cc + + >>> cc.no_cache + '*' + >>> cc.no_cache = None + >>> cc + + + In versions before 0.5 the behavior documented here affected the now + no longer existing `CacheControl` class. + """ + + no_cache = cache_control_property("no-cache", "*", None) + no_store = cache_control_property("no-store", None, bool) + max_age = cache_control_property("max-age", -1, int) + no_transform = cache_control_property("no-transform", None, None) + + def __init__(self, values=(), on_update=None): + dict.__init__(self, values or ()) + self.on_update = on_update + self.provided = values is not None + + def _get_cache_value(self, key, empty, type): + """Used internally by the accessor properties.""" + if type is bool: + return key in self + if key in self: + value = self[key] + if value is None: + return empty + elif type is not None: + try: + value = type(value) + except ValueError: + pass + return value + return None + + def _set_cache_value(self, key, value, type): + """Used internally by the accessor properties.""" + if type is bool: + if value: + self[key] = None + else: + self.pop(key, None) + else: + if value is None: + self.pop(key, None) + elif value is True: + self[key] = None + else: + if type is not None: + self[key] = type(value) + else: + self[key] = value + + def _del_cache_value(self, key): + """Used internally by the accessor properties.""" + if key in self: + del self[key] + + def to_header(self): + """Convert the stored 
values into a cache control header.""" + return http.dump_header(self) + + def __str__(self): + return self.to_header() + + def __repr__(self): + kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) + return f"<{type(self).__name__} {kv_str}>" + + cache_property = staticmethod(cache_control_property) + + +class RequestCacheControl(ImmutableDictMixin, _CacheControl): + """A cache control for requests. This is immutable and gives access + to all the request-relevant cache control headers. + + To get a header of the :class:`RequestCacheControl` object again you can + convert the object into a string or call the :meth:`to_header` method. If + you plan to subclass it and add your own items have a look at the sourcecode + for that class. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionadded:: 0.5 + In previous versions a `CacheControl` class existed that was used + both for request and response. + """ + + max_stale = cache_control_property("max-stale", "*", int) + min_fresh = cache_control_property("min-fresh", "*", int) + only_if_cached = cache_control_property("only-if-cached", None, bool) + + +class ResponseCacheControl(_CacheControl): + """A cache control for responses. Unlike :class:`RequestCacheControl` + this is mutable and gives access to response-relevant cache control + headers. + + To get a header of the :class:`ResponseCacheControl` object again you can + convert the object into a string or call the :meth:`to_header` method. If + you plan to subclass it and add your own items have a look at the sourcecode + for that class. + + .. versionchanged:: 2.1.1 + ``s_maxage`` converts the value to an int. + + .. versionchanged:: 2.1.0 + Setting int properties such as ``max_age`` will convert the + value to an int. + + .. versionadded:: 0.5 + In previous versions a `CacheControl` class existed that was used + both for request and response. 
+ """ + + public = cache_control_property("public", None, bool) + private = cache_control_property("private", "*", None) + must_revalidate = cache_control_property("must-revalidate", None, bool) + proxy_revalidate = cache_control_property("proxy-revalidate", None, bool) + s_maxage = cache_control_property("s-maxage", None, int) + immutable = cache_control_property("immutable", None, bool) + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/cache_control.pyi b/src/werkzeug/datastructures/cache_control.pyi new file mode 100644 index 0000000..06fe667 --- /dev/null +++ b/src/werkzeug/datastructures/cache_control.pyi @@ -0,0 +1,109 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Mapping +from typing import TypeVar + +from .mixins import ImmutableDictMixin +from .mixins import UpdateDictMixin + +T = TypeVar("T") +_CPT = TypeVar("_CPT", str, int, bool) +_OptCPT = _CPT | None + +def cache_control_property(key: str, empty: _OptCPT, type: type[_CPT]) -> property: ... + +class _CacheControl(UpdateDictMixin[str, _OptCPT], dict[str, _OptCPT]): + provided: bool + def __init__( + self, + values: Mapping[str, _OptCPT] | Iterable[tuple[str, _OptCPT]] = (), + on_update: Callable[[_CacheControl], None] | None = None, + ) -> None: ... + @property + def no_cache(self) -> bool | None: ... + @no_cache.setter + def no_cache(self, value: bool | None) -> None: ... + @no_cache.deleter + def no_cache(self) -> None: ... + @property + def no_store(self) -> bool | None: ... + @no_store.setter + def no_store(self, value: bool | None) -> None: ... + @no_store.deleter + def no_store(self) -> None: ... + @property + def max_age(self) -> int | None: ... + @max_age.setter + def max_age(self, value: int | None) -> None: ... + @max_age.deleter + def max_age(self) -> None: ... + @property + def no_transform(self) -> bool | None: ... 
+ @no_transform.setter + def no_transform(self, value: bool | None) -> None: ... + @no_transform.deleter + def no_transform(self) -> None: ... + def _get_cache_value(self, key: str, empty: T | None, type: type[T]) -> T: ... + def _set_cache_value(self, key: str, value: T | None, type: type[T]) -> None: ... + def _del_cache_value(self, key: str) -> None: ... + def to_header(self) -> str: ... + @staticmethod + def cache_property(key: str, empty: _OptCPT, type: type[_CPT]) -> property: ... + +class RequestCacheControl(ImmutableDictMixin[str, _OptCPT], _CacheControl): + @property + def max_stale(self) -> int | None: ... + @max_stale.setter + def max_stale(self, value: int | None) -> None: ... + @max_stale.deleter + def max_stale(self) -> None: ... + @property + def min_fresh(self) -> int | None: ... + @min_fresh.setter + def min_fresh(self, value: int | None) -> None: ... + @min_fresh.deleter + def min_fresh(self) -> None: ... + @property + def only_if_cached(self) -> bool | None: ... + @only_if_cached.setter + def only_if_cached(self, value: bool | None) -> None: ... + @only_if_cached.deleter + def only_if_cached(self) -> None: ... + +class ResponseCacheControl(_CacheControl): + @property + def public(self) -> bool | None: ... + @public.setter + def public(self, value: bool | None) -> None: ... + @public.deleter + def public(self) -> None: ... + @property + def private(self) -> bool | None: ... + @private.setter + def private(self, value: bool | None) -> None: ... + @private.deleter + def private(self) -> None: ... + @property + def must_revalidate(self) -> bool | None: ... + @must_revalidate.setter + def must_revalidate(self, value: bool | None) -> None: ... + @must_revalidate.deleter + def must_revalidate(self) -> None: ... + @property + def proxy_revalidate(self) -> bool | None: ... + @proxy_revalidate.setter + def proxy_revalidate(self, value: bool | None) -> None: ... + @proxy_revalidate.deleter + def proxy_revalidate(self) -> None: ... 
+ @property + def s_maxage(self) -> int | None: ... + @s_maxage.setter + def s_maxage(self, value: int | None) -> None: ... + @s_maxage.deleter + def s_maxage(self) -> None: ... + @property + def immutable(self) -> bool | None: ... + @immutable.setter + def immutable(self, value: bool | None) -> None: ... + @immutable.deleter + def immutable(self) -> None: ... diff --git a/src/werkzeug/datastructures/csp.py b/src/werkzeug/datastructures/csp.py new file mode 100644 index 0000000..dde9414 --- /dev/null +++ b/src/werkzeug/datastructures/csp.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from .mixins import UpdateDictMixin + + +def csp_property(key): + """Return a new property object for a content security policy header. + Useful if you want to add support for a csp extension in a + subclass. + """ + return property( + lambda x: x._get_value(key), + lambda x, v: x._set_value(key, v), + lambda x: x._del_value(key), + f"accessor for {key!r}", + ) + + +class ContentSecurityPolicy(UpdateDictMixin, dict): + """Subclass of a dict that stores values for a Content Security Policy + header. It has accessors for all the level 3 policies. + + Because the csp directives in the HTTP header use dashes the + python descriptors use underscores for that. + + To get a header of the :class:`ContentSecuirtyPolicy` object again + you can convert the object into a string or call the + :meth:`to_header` method. If you plan to subclass it and add your + own items have a look at the sourcecode for that class. + + .. versionadded:: 1.0.0 + Support for Content Security Policy headers was added. 
+ + """ + + base_uri = csp_property("base-uri") + child_src = csp_property("child-src") + connect_src = csp_property("connect-src") + default_src = csp_property("default-src") + font_src = csp_property("font-src") + form_action = csp_property("form-action") + frame_ancestors = csp_property("frame-ancestors") + frame_src = csp_property("frame-src") + img_src = csp_property("img-src") + manifest_src = csp_property("manifest-src") + media_src = csp_property("media-src") + navigate_to = csp_property("navigate-to") + object_src = csp_property("object-src") + prefetch_src = csp_property("prefetch-src") + plugin_types = csp_property("plugin-types") + report_to = csp_property("report-to") + report_uri = csp_property("report-uri") + sandbox = csp_property("sandbox") + script_src = csp_property("script-src") + script_src_attr = csp_property("script-src-attr") + script_src_elem = csp_property("script-src-elem") + style_src = csp_property("style-src") + style_src_attr = csp_property("style-src-attr") + style_src_elem = csp_property("style-src-elem") + worker_src = csp_property("worker-src") + + def __init__(self, values=(), on_update=None): + dict.__init__(self, values or ()) + self.on_update = on_update + self.provided = values is not None + + def _get_value(self, key): + """Used internally by the accessor properties.""" + return self.get(key) + + def _set_value(self, key, value): + """Used internally by the accessor properties.""" + if value is None: + self.pop(key, None) + else: + self[key] = value + + def _del_value(self, key): + """Used internally by the accessor properties.""" + if key in self: + del self[key] + + def to_header(self): + """Convert the stored values into a cache control header.""" + from ..http import dump_csp_header + + return dump_csp_header(self) + + def __str__(self): + return self.to_header() + + def __repr__(self): + kv_str = " ".join(f"{k}={v!r}" for k, v in sorted(self.items())) + return f"<{type(self).__name__} {kv_str}>" diff --git 
a/src/werkzeug/datastructures/csp.pyi b/src/werkzeug/datastructures/csp.pyi new file mode 100644 index 0000000..f9e2ac0 --- /dev/null +++ b/src/werkzeug/datastructures/csp.pyi @@ -0,0 +1,169 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Mapping + +from .mixins import UpdateDictMixin + +def csp_property(key: str) -> property: ... + +class ContentSecurityPolicy(UpdateDictMixin[str, str], dict[str, str]): + @property + def base_uri(self) -> str | None: ... + @base_uri.setter + def base_uri(self, value: str | None) -> None: ... + @base_uri.deleter + def base_uri(self) -> None: ... + @property + def child_src(self) -> str | None: ... + @child_src.setter + def child_src(self, value: str | None) -> None: ... + @child_src.deleter + def child_src(self) -> None: ... + @property + def connect_src(self) -> str | None: ... + @connect_src.setter + def connect_src(self, value: str | None) -> None: ... + @connect_src.deleter + def connect_src(self) -> None: ... + @property + def default_src(self) -> str | None: ... + @default_src.setter + def default_src(self, value: str | None) -> None: ... + @default_src.deleter + def default_src(self) -> None: ... + @property + def font_src(self) -> str | None: ... + @font_src.setter + def font_src(self, value: str | None) -> None: ... + @font_src.deleter + def font_src(self) -> None: ... + @property + def form_action(self) -> str | None: ... + @form_action.setter + def form_action(self, value: str | None) -> None: ... + @form_action.deleter + def form_action(self) -> None: ... + @property + def frame_ancestors(self) -> str | None: ... + @frame_ancestors.setter + def frame_ancestors(self, value: str | None) -> None: ... + @frame_ancestors.deleter + def frame_ancestors(self) -> None: ... + @property + def frame_src(self) -> str | None: ... + @frame_src.setter + def frame_src(self, value: str | None) -> None: ... + @frame_src.deleter + def frame_src(self) -> None: ... 
+ @property + def img_src(self) -> str | None: ... + @img_src.setter + def img_src(self, value: str | None) -> None: ... + @img_src.deleter + def img_src(self) -> None: ... + @property + def manifest_src(self) -> str | None: ... + @manifest_src.setter + def manifest_src(self, value: str | None) -> None: ... + @manifest_src.deleter + def manifest_src(self) -> None: ... + @property + def media_src(self) -> str | None: ... + @media_src.setter + def media_src(self, value: str | None) -> None: ... + @media_src.deleter + def media_src(self) -> None: ... + @property + def navigate_to(self) -> str | None: ... + @navigate_to.setter + def navigate_to(self, value: str | None) -> None: ... + @navigate_to.deleter + def navigate_to(self) -> None: ... + @property + def object_src(self) -> str | None: ... + @object_src.setter + def object_src(self, value: str | None) -> None: ... + @object_src.deleter + def object_src(self) -> None: ... + @property + def prefetch_src(self) -> str | None: ... + @prefetch_src.setter + def prefetch_src(self, value: str | None) -> None: ... + @prefetch_src.deleter + def prefetch_src(self) -> None: ... + @property + def plugin_types(self) -> str | None: ... + @plugin_types.setter + def plugin_types(self, value: str | None) -> None: ... + @plugin_types.deleter + def plugin_types(self) -> None: ... + @property + def report_to(self) -> str | None: ... + @report_to.setter + def report_to(self, value: str | None) -> None: ... + @report_to.deleter + def report_to(self) -> None: ... + @property + def report_uri(self) -> str | None: ... + @report_uri.setter + def report_uri(self, value: str | None) -> None: ... + @report_uri.deleter + def report_uri(self) -> None: ... + @property + def sandbox(self) -> str | None: ... + @sandbox.setter + def sandbox(self, value: str | None) -> None: ... + @sandbox.deleter + def sandbox(self) -> None: ... + @property + def script_src(self) -> str | None: ... 
+ @script_src.setter + def script_src(self, value: str | None) -> None: ... + @script_src.deleter + def script_src(self) -> None: ... + @property + def script_src_attr(self) -> str | None: ... + @script_src_attr.setter + def script_src_attr(self, value: str | None) -> None: ... + @script_src_attr.deleter + def script_src_attr(self) -> None: ... + @property + def script_src_elem(self) -> str | None: ... + @script_src_elem.setter + def script_src_elem(self, value: str | None) -> None: ... + @script_src_elem.deleter + def script_src_elem(self) -> None: ... + @property + def style_src(self) -> str | None: ... + @style_src.setter + def style_src(self, value: str | None) -> None: ... + @style_src.deleter + def style_src(self) -> None: ... + @property + def style_src_attr(self) -> str | None: ... + @style_src_attr.setter + def style_src_attr(self, value: str | None) -> None: ... + @style_src_attr.deleter + def style_src_attr(self) -> None: ... + @property + def style_src_elem(self) -> str | None: ... + @style_src_elem.setter + def style_src_elem(self, value: str | None) -> None: ... + @style_src_elem.deleter + def style_src_elem(self) -> None: ... + @property + def worker_src(self) -> str | None: ... + @worker_src.setter + def worker_src(self, value: str | None) -> None: ... + @worker_src.deleter + def worker_src(self) -> None: ... + provided: bool + def __init__( + self, + values: Mapping[str, str] | Iterable[tuple[str, str]] = (), + on_update: Callable[[ContentSecurityPolicy], None] | None = None, + ) -> None: ... + def _get_value(self, key: str) -> str | None: ... + def _set_value(self, key: str, value: str) -> None: ... + def _del_value(self, key: str) -> None: ... + def to_header(self) -> str: ... 
diff --git a/src/werkzeug/datastructures/etag.py b/src/werkzeug/datastructures/etag.py new file mode 100644 index 0000000..747d996 --- /dev/null +++ b/src/werkzeug/datastructures/etag.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from collections.abc import Collection + + +class ETags(Collection): + """A set that can be used to check if one etag is present in a collection + of etags. + """ + + def __init__(self, strong_etags=None, weak_etags=None, star_tag=False): + if not star_tag and strong_etags: + self._strong = frozenset(strong_etags) + else: + self._strong = frozenset() + + self._weak = frozenset(weak_etags or ()) + self.star_tag = star_tag + + def as_set(self, include_weak=False): + """Convert the `ETags` object into a python set. Per default all the + weak etags are not part of this set.""" + rv = set(self._strong) + if include_weak: + rv.update(self._weak) + return rv + + def is_weak(self, etag): + """Check if an etag is weak.""" + return etag in self._weak + + def is_strong(self, etag): + """Check if an etag is strong.""" + return etag in self._strong + + def contains_weak(self, etag): + """Check if an etag is part of the set including weak and strong tags.""" + return self.is_weak(etag) or self.contains(etag) + + def contains(self, etag): + """Check if an etag is part of the set ignoring weak tags. + It is also possible to use the ``in`` operator. + """ + if self.star_tag: + return True + return self.is_strong(etag) + + def contains_raw(self, etag): + """When passed a quoted tag it will check if this tag is part of the + set. 
If the tag is weak it is checked against weak and strong tags, + otherwise strong only.""" + from ..http import unquote_etag + + etag, weak = unquote_etag(etag) + if weak: + return self.contains_weak(etag) + return self.contains(etag) + + def to_header(self): + """Convert the etags set into a HTTP header string.""" + if self.star_tag: + return "*" + return ", ".join( + [f'"{x}"' for x in self._strong] + [f'W/"{x}"' for x in self._weak] + ) + + def __call__(self, etag=None, data=None, include_weak=False): + if [etag, data].count(None) != 1: + raise TypeError("either tag or data required, but at least one") + if etag is None: + from ..http import generate_etag + + etag = generate_etag(data) + if include_weak: + if etag in self._weak: + return True + return etag in self._strong + + def __bool__(self): + return bool(self.star_tag or self._strong or self._weak) + + def __str__(self): + return self.to_header() + + def __len__(self): + return len(self._strong) + + def __iter__(self): + return iter(self._strong) + + def __contains__(self, etag): + return self.contains(etag) + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" diff --git a/src/werkzeug/datastructures/etag.pyi b/src/werkzeug/datastructures/etag.pyi new file mode 100644 index 0000000..88e54f1 --- /dev/null +++ b/src/werkzeug/datastructures/etag.pyi @@ -0,0 +1,30 @@ +from collections.abc import Collection +from collections.abc import Iterable +from collections.abc import Iterator + +class ETags(Collection[str]): + _strong: frozenset[str] + _weak: frozenset[str] + star_tag: bool + def __init__( + self, + strong_etags: Iterable[str] | None = None, + weak_etags: Iterable[str] | None = None, + star_tag: bool = False, + ) -> None: ... + def as_set(self, include_weak: bool = False) -> set[str]: ... + def is_weak(self, etag: str) -> bool: ... + def is_strong(self, etag: str) -> bool: ... + def contains_weak(self, etag: str) -> bool: ... + def contains(self, etag: str) -> bool: ... 
+ def contains_raw(self, etag: str) -> bool: ... + def to_header(self) -> str: ... + def __call__( + self, + etag: str | None = None, + data: bytes | None = None, + include_weak: bool = False, + ) -> bool: ... + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... + def __contains__(self, item: str) -> bool: ... # type: ignore diff --git a/src/werkzeug/datastructures/file_storage.py b/src/werkzeug/datastructures/file_storage.py new file mode 100644 index 0000000..e878a56 --- /dev/null +++ b/src/werkzeug/datastructures/file_storage.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +import mimetypes +from io import BytesIO +from os import fsdecode +from os import fspath + +from .._internal import _plain_int +from .structures import MultiDict + + +class FileStorage: + """The :class:`FileStorage` class is a thin wrapper over incoming files. + It is used by the request object to represent uploaded files. All the + attributes of the wrapper stream are proxied by the file storage so + it's possible to do ``storage.read()`` instead of the long form + ``storage.stream.read()``. + """ + + def __init__( + self, + stream=None, + filename=None, + name=None, + content_type=None, + content_length=None, + headers=None, + ): + self.name = name + self.stream = stream or BytesIO() + + # If no filename is provided, attempt to get the filename from + # the stream object. Python names special streams like + # ```` with angular brackets, skip these streams. 
+ if filename is None: + filename = getattr(stream, "name", None) + + if filename is not None: + filename = fsdecode(filename) + + if filename and filename[0] == "<" and filename[-1] == ">": + filename = None + else: + filename = fsdecode(filename) + + self.filename = filename + + if headers is None: + from .headers import Headers + + headers = Headers() + self.headers = headers + if content_type is not None: + headers["Content-Type"] = content_type + if content_length is not None: + headers["Content-Length"] = str(content_length) + + def _parse_content_type(self): + if not hasattr(self, "_parsed_content_type"): + self._parsed_content_type = http.parse_options_header(self.content_type) + + @property + def content_type(self): + """The content-type sent in the header. Usually not available""" + return self.headers.get("content-type") + + @property + def content_length(self): + """The content-length sent in the header. Usually not available""" + if "content-length" in self.headers: + try: + return _plain_int(self.headers["content-length"]) + except ValueError: + pass + + return 0 + + @property + def mimetype(self): + """Like :attr:`content_type`, but without parameters (eg, without + charset, type etc.) and always lowercase. For example if the content + type is ``text/HTML; charset=utf-8`` the mimetype would be + ``'text/html'``. + + .. versionadded:: 0.7 + """ + self._parse_content_type() + return self._parsed_content_type[0].lower() + + @property + def mimetype_params(self): + """The mimetype parameters as dict. For example if the content + type is ``text/html; charset=utf-8`` the params would be + ``{'charset': 'utf-8'}``. + + .. versionadded:: 0.7 + """ + self._parse_content_type() + return self._parsed_content_type[1] + + def save(self, dst, buffer_size=16384): + """Save the file to a destination path or file object. If the + destination is a file object you have to close it yourself after the + call. 
The buffer size is the number of bytes held in memory during + the copy process. It defaults to 16KB. + + For secure file saving also have a look at :func:`secure_filename`. + + :param dst: a filename, :class:`os.PathLike`, or open file + object to write to. + :param buffer_size: Passed as the ``length`` parameter of + :func:`shutil.copyfileobj`. + + .. versionchanged:: 1.0 + Supports :mod:`pathlib`. + """ + from shutil import copyfileobj + + close_dst = False + + if hasattr(dst, "__fspath__"): + dst = fspath(dst) + + if isinstance(dst, str): + dst = open(dst, "wb") + close_dst = True + + try: + copyfileobj(self.stream, dst, buffer_size) + finally: + if close_dst: + dst.close() + + def close(self): + """Close the underlying file if possible.""" + try: + self.stream.close() + except Exception: + pass + + def __bool__(self): + return bool(self.filename) + + def __getattr__(self, name): + try: + return getattr(self.stream, name) + except AttributeError: + # SpooledTemporaryFile doesn't implement IOBase, get the + # attribute from its backing file instead. + # https://github.com/python/cpython/pull/3249 + if hasattr(self.stream, "_file"): + return getattr(self.stream._file, name) + raise + + def __iter__(self): + return iter(self.stream) + + def __repr__(self): + return f"<{type(self).__name__}: {self.filename!r} ({self.content_type!r})>" + + +class FileMultiDict(MultiDict): + """A special :class:`MultiDict` that has convenience methods to add + files to it. This is used for :class:`EnvironBuilder` and generally + useful for unittesting. + + .. versionadded:: 0.5 + """ + + def add_file(self, name, file, filename=None, content_type=None): + """Adds a new file to the dict. `file` can be a file name or + a :class:`file`-like or a :class:`FileStorage` object. + + :param name: the name of the field. 
+ :param file: a filename or :class:`file`-like object + :param filename: an optional filename + :param content_type: an optional content type + """ + if isinstance(file, FileStorage): + value = file + else: + if isinstance(file, str): + if filename is None: + filename = file + file = open(file, "rb") + if filename and content_type is None: + content_type = ( + mimetypes.guess_type(filename)[0] or "application/octet-stream" + ) + value = FileStorage(file, filename, name, content_type) + + self.add(name, value) + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/file_storage.pyi b/src/werkzeug/datastructures/file_storage.pyi new file mode 100644 index 0000000..730789e --- /dev/null +++ b/src/werkzeug/datastructures/file_storage.pyi @@ -0,0 +1,47 @@ +from collections.abc import Iterator +from os import PathLike +from typing import Any +from typing import IO + +from .headers import Headers +from .structures import MultiDict + +class FileStorage: + name: str | None + stream: IO[bytes] + filename: str | None + headers: Headers + _parsed_content_type: tuple[str, dict[str, str]] + def __init__( + self, + stream: IO[bytes] | None = None, + filename: str | PathLike | None = None, + name: str | None = None, + content_type: str | None = None, + content_length: int | None = None, + headers: Headers | None = None, + ) -> None: ... + def _parse_content_type(self) -> None: ... + @property + def content_type(self) -> str: ... + @property + def content_length(self) -> int: ... + @property + def mimetype(self) -> str: ... + @property + def mimetype_params(self) -> dict[str, str]: ... + def save(self, dst: str | PathLike | IO[bytes], buffer_size: int = ...) -> None: ... + def close(self) -> None: ... + def __bool__(self) -> bool: ... + def __getattr__(self, name: str) -> Any: ... + def __iter__(self) -> Iterator[bytes]: ... + def __repr__(self) -> str: ... 
+ +class FileMultiDict(MultiDict[str, FileStorage]): + def add_file( + self, + name: str, + file: FileStorage | str | IO[bytes], + filename: str | None = None, + content_type: str | None = None, + ) -> None: ... diff --git a/src/werkzeug/datastructures/headers.py b/src/werkzeug/datastructures/headers.py new file mode 100644 index 0000000..d9dd655 --- /dev/null +++ b/src/werkzeug/datastructures/headers.py @@ -0,0 +1,515 @@ +from __future__ import annotations + +import re +import typing as t + +from .._internal import _missing +from ..exceptions import BadRequestKeyError +from .mixins import ImmutableHeadersMixin +from .structures import iter_multi_items +from .structures import MultiDict + + +class Headers: + """An object that stores some headers. It has a dict-like interface, + but is ordered, can store the same key multiple times, and iterating + yields ``(key, value)`` pairs instead of only keys. + + This data structure is useful if you want a nicer way to handle WSGI + headers which are stored as tuples in a list. + + From Werkzeug 0.3 onwards, the :exc:`KeyError` raised by this class is + also a subclass of the :class:`~exceptions.BadRequest` HTTP exception + and will render a page for a ``400 BAD REQUEST`` if caught in a + catch-all for HTTP exceptions. + + Headers is mostly compatible with the Python :class:`wsgiref.headers.Headers` + class, with the exception of `__getitem__`. :mod:`wsgiref` will return + `None` for ``headers['missing']``, whereas :class:`Headers` will raise + a :class:`KeyError`. + + To create a new ``Headers`` object, pass it a list, dict, or + other ``Headers`` object with default values. These values are + validated the same way values added later are. + + :param defaults: The list of default values for the :class:`Headers`. + + .. versionchanged:: 2.1.0 + Default values are validated the same as values added later. + + .. versionchanged:: 0.9 + This data structure now stores unicode values similar to how the + multi dicts do it. 
The main difference is that bytes can be set as + well which will automatically be latin1 decoded. + + .. versionchanged:: 0.9 + The :meth:`linked` function was removed without replacement as it + was an API that does not support the changes to the encoding model. + """ + + def __init__(self, defaults=None): + self._list = [] + if defaults is not None: + self.extend(defaults) + + def __getitem__(self, key, _get_mode=False): + if not _get_mode: + if isinstance(key, int): + return self._list[key] + elif isinstance(key, slice): + return self.__class__(self._list[key]) + if not isinstance(key, str): + raise BadRequestKeyError(key) + ikey = key.lower() + for k, v in self._list: + if k.lower() == ikey: + return v + # micro optimization: if we are in get mode we will catch that + # exception one stack level down so we can raise a standard + # key error instead of our special one. + if _get_mode: + raise KeyError() + raise BadRequestKeyError(key) + + def __eq__(self, other): + def lowered(item): + return (item[0].lower(),) + item[1:] + + return other.__class__ is self.__class__ and set( + map(lowered, other._list) + ) == set(map(lowered, self._list)) + + __hash__ = None + + def get(self, key, default=None, type=None): + """Return the default value if the requested data doesn't exist. + If `type` is provided and is a callable it should convert the value, + return it or raise a :exc:`ValueError` if that is not possible. In + this case the function will return the default as if the value was not + found: + + >>> d = Headers([('Content-Length', '42')]) + >>> d.get('Content-Length', type=int) + 42 + + :param key: The key to be looked up. + :param default: The default value to be returned if the key can't + be looked up. If not further specified `None` is + returned. + :param type: A callable that is used to cast the value in the + :class:`Headers`. If a :exc:`ValueError` is raised + by this callable the default value is returned. + + .. 
versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. + + .. versionchanged:: 0.9 + The ``as_bytes`` parameter was added. + """ + try: + rv = self.__getitem__(key, _get_mode=True) + except KeyError: + return default + if type is None: + return rv + try: + return type(rv) + except ValueError: + return default + + def getlist(self, key, type=None): + """Return the list of items for a given key. If that key is not in the + :class:`Headers`, the return value will be an empty list. Just like + :meth:`get`, :meth:`getlist` accepts a `type` parameter. All items will + be converted with the callable defined there. + + :param key: The key to be looked up. + :param type: A callable that is used to cast the value in the + :class:`Headers`. If a :exc:`ValueError` is raised + by this callable the value will be removed from the list. + :return: a :class:`list` of all the values for the key. + + .. versionchanged:: 3.0 + The ``as_bytes`` parameter was removed. + + .. versionchanged:: 0.9 + The ``as_bytes`` parameter was added. + """ + ikey = key.lower() + result = [] + for k, v in self: + if k.lower() == ikey: + if type is not None: + try: + v = type(v) + except ValueError: + continue + result.append(v) + return result + + def get_all(self, name): + """Return a list of all the values for the named field. + + This method is compatible with the :mod:`wsgiref` + :meth:`~wsgiref.headers.Headers.get_all` method. + """ + return self.getlist(name) + + def items(self, lower=False): + for key, value in self: + if lower: + key = key.lower() + yield key, value + + def keys(self, lower=False): + for key, _ in self.items(lower): + yield key + + def values(self): + for _, value in self.items(): + yield value + + def extend(self, *args, **kwargs): + """Extend headers in this object with items from another object + containing header items as well as keyword arguments. + + To replace existing keys instead of extending, use + :meth:`update` instead. 
+ + If provided, the first argument can be another :class:`Headers` + object, a :class:`MultiDict`, :class:`dict`, or iterable of + pairs. + + .. versionchanged:: 1.0 + Support :class:`MultiDict`. Allow passing ``kwargs``. + """ + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + + if args: + for key, value in iter_multi_items(args[0]): + self.add(key, value) + + for key, value in iter_multi_items(kwargs): + self.add(key, value) + + def __delitem__(self, key, _index_operation=True): + if _index_operation and isinstance(key, (int, slice)): + del self._list[key] + return + key = key.lower() + new = [] + for k, v in self._list: + if k.lower() != key: + new.append((k, v)) + self._list[:] = new + + def remove(self, key): + """Remove a key. + + :param key: The key to be removed. + """ + return self.__delitem__(key, _index_operation=False) + + def pop(self, key=None, default=_missing): + """Removes and returns a key or index. + + :param key: The key to be popped. If this is an integer the item at + that position is removed, if it's a string the value for + that key is. If the key is omitted or `None` the last + item is removed. + :return: an item. + """ + if key is None: + return self._list.pop() + if isinstance(key, int): + return self._list.pop(key) + try: + rv = self[key] + self.remove(key) + except KeyError: + if default is not _missing: + return default + raise + return rv + + def popitem(self): + """Removes a key or index and returns a (key, value) item.""" + return self.pop() + + def __contains__(self, key): + """Check if a key is present.""" + try: + self.__getitem__(key, _get_mode=True) + except KeyError: + return False + return True + + def __iter__(self): + """Yield ``(key, value)`` tuples.""" + return iter(self._list) + + def __len__(self): + return len(self._list) + + def add(self, _key, _value, **kw): + """Add a new header tuple to the list. 
+ + Keyword arguments can specify additional parameters for the header + value, with underscores converted to dashes:: + + >>> d = Headers() + >>> d.add('Content-Type', 'text/plain') + >>> d.add('Content-Disposition', 'attachment', filename='foo.png') + + The keyword argument dumping uses :func:`dump_options_header` + behind the scenes. + + .. versionadded:: 0.4.1 + keyword arguments were added for :mod:`wsgiref` compatibility. + """ + if kw: + _value = _options_header_vkw(_value, kw) + _value = _str_header_value(_value) + self._list.append((_key, _value)) + + def add_header(self, _key, _value, **_kw): + """Add a new header tuple to the list. + + An alias for :meth:`add` for compatibility with the :mod:`wsgiref` + :meth:`~wsgiref.headers.Headers.add_header` method. + """ + self.add(_key, _value, **_kw) + + def clear(self): + """Clears all headers.""" + del self._list[:] + + def set(self, _key, _value, **kw): + """Remove all header tuples for `key` and add a new one. The newly + added key either appears at the end of the list if there was no + entry or replaces the first one. + + Keyword arguments can specify additional parameters for the header + value, with underscores converted to dashes. See :meth:`add` for + more information. + + .. versionchanged:: 0.6.1 + :meth:`set` now accepts the same arguments as :meth:`add`. + + :param key: The key to be inserted. + :param value: The value to be inserted. 
+ """ + if kw: + _value = _options_header_vkw(_value, kw) + _value = _str_header_value(_value) + if not self._list: + self._list.append((_key, _value)) + return + listiter = iter(self._list) + ikey = _key.lower() + for idx, (old_key, _old_value) in enumerate(listiter): + if old_key.lower() == ikey: + # replace first occurrence + self._list[idx] = (_key, _value) + break + else: + self._list.append((_key, _value)) + return + self._list[idx + 1 :] = [t for t in listiter if t[0].lower() != ikey] + + def setlist(self, key, values): + """Remove any existing values for a header and add new ones. + + :param key: The header key to set. + :param values: An iterable of values to set for the key. + + .. versionadded:: 1.0 + """ + if values: + values_iter = iter(values) + self.set(key, next(values_iter)) + + for value in values_iter: + self.add(key, value) + else: + self.remove(key) + + def setdefault(self, key, default): + """Return the first value for the key if it is in the headers, + otherwise set the header to the value given by ``default`` and + return that. + + :param key: The header key to get. + :param default: The value to set for the key if it is not in the + headers. + """ + if key in self: + return self[key] + + self.set(key, default) + return default + + def setlistdefault(self, key, default): + """Return the list of values for the key if it is in the + headers, otherwise set the header to the list of values given + by ``default`` and return that. + + Unlike :meth:`MultiDict.setlistdefault`, modifying the returned + list will not affect the headers. + + :param key: The header key to get. + :param default: An iterable of values to set for the key if it + is not in the headers. + + .. 
versionadded:: 1.0 + """ + if key not in self: + self.setlist(key, default) + + return self.getlist(key) + + def __setitem__(self, key, value): + """Like :meth:`set` but also supports index/slice based setting.""" + if isinstance(key, (slice, int)): + if isinstance(key, int): + value = [value] + value = [(k, _str_header_value(v)) for (k, v) in value] + if isinstance(key, int): + self._list[key] = value[0] + else: + self._list[key] = value + else: + self.set(key, value) + + def update(self, *args, **kwargs): + """Replace headers in this object with items from another + headers object and keyword arguments. + + To extend existing keys instead of replacing, use :meth:`extend` + instead. + + If provided, the first argument can be another :class:`Headers` + object, a :class:`MultiDict`, :class:`dict`, or iterable of + pairs. + + .. versionadded:: 1.0 + """ + if len(args) > 1: + raise TypeError(f"update expected at most 1 arguments, got {len(args)}") + + if args: + mapping = args[0] + + if isinstance(mapping, (Headers, MultiDict)): + for key in mapping.keys(): + self.setlist(key, mapping.getlist(key)) + elif isinstance(mapping, dict): + for key, value in mapping.items(): + if isinstance(value, (list, tuple)): + self.setlist(key, value) + else: + self.set(key, value) + else: + for key, value in mapping: + self.set(key, value) + + for key, value in kwargs.items(): + if isinstance(value, (list, tuple)): + self.setlist(key, value) + else: + self.set(key, value) + + def to_wsgi_list(self): + """Convert the headers into a list suitable for WSGI. 
+ + :return: list + """ + return list(self) + + def copy(self): + return self.__class__(self._list) + + def __copy__(self): + return self.copy() + + def __str__(self): + """Returns formatted headers suitable for HTTP transmission.""" + strs = [] + for key, value in self.to_wsgi_list(): + strs.append(f"{key}: {value}") + strs.append("\r\n") + return "\r\n".join(strs) + + def __repr__(self): + return f"{type(self).__name__}({list(self)!r})" + + +def _options_header_vkw(value: str, kw: dict[str, t.Any]): + return http.dump_options_header( + value, {k.replace("_", "-"): v for k, v in kw.items()} + ) + + +_newline_re = re.compile(r"[\r\n]") + + +def _str_header_value(value: t.Any) -> str: + if not isinstance(value, str): + value = str(value) + + if _newline_re.search(value) is not None: + raise ValueError("Header values must not contain newline characters.") + + return value + + +class EnvironHeaders(ImmutableHeadersMixin, Headers): + """Read only version of the headers from a WSGI environment. This + provides the same interface as `Headers` and is constructed from + a WSGI environment. + From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a + subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will + render a page for a ``400 BAD REQUEST`` if caught in a catch-all for + HTTP exceptions. + """ + + def __init__(self, environ): + self.environ = environ + + def __eq__(self, other): + return self.environ is other.environ + + __hash__ = None + + def __getitem__(self, key, _get_mode=False): + # _get_mode is a no-op for this class as there is no index but + # used because get() calls it. + if not isinstance(key, str): + raise KeyError(key) + key = key.upper().replace("-", "_") + if key in {"CONTENT_TYPE", "CONTENT_LENGTH"}: + return self.environ[key] + return self.environ[f"HTTP_{key}"] + + def __len__(self): + # the iter is necessary because otherwise list calls our + # len which would call list again and so forth. 
+ return len(list(iter(self))) + + def __iter__(self): + for key, value in self.environ.items(): + if key.startswith("HTTP_") and key not in { + "HTTP_CONTENT_TYPE", + "HTTP_CONTENT_LENGTH", + }: + yield key[5:].replace("_", "-").title(), value + elif key in {"CONTENT_TYPE", "CONTENT_LENGTH"} and value: + yield key.replace("_", "-").title(), value + + def copy(self): + raise TypeError(f"cannot create {type(self).__name__!r} copies") + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/headers.pyi b/src/werkzeug/datastructures/headers.pyi new file mode 100644 index 0000000..8650222 --- /dev/null +++ b/src/werkzeug/datastructures/headers.pyi @@ -0,0 +1,109 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from typing import Literal +from typing import NoReturn +from typing import overload +from typing import TypeVar + +from _typeshed import SupportsKeysAndGetItem +from _typeshed.wsgi import WSGIEnvironment + +from .mixins import ImmutableHeadersMixin + +D = TypeVar("D") +T = TypeVar("T") + +class Headers(dict[str, str]): + _list: list[tuple[str, str]] + def __init__( + self, + defaults: Mapping[str, str | Iterable[str]] + | Iterable[tuple[str, str]] + | None = None, + ) -> None: ... + @overload + def __getitem__(self, key: str) -> str: ... + @overload + def __getitem__(self, key: int) -> tuple[str, str]: ... + @overload + def __getitem__(self, key: slice) -> Headers: ... + @overload + def __getitem__(self, key: str, _get_mode: Literal[True] = ...) -> str: ... + def __eq__(self, other: object) -> bool: ... + @overload # type: ignore + def get(self, key: str, default: str) -> str: ... + @overload + def get(self, key: str, default: str | None = None) -> str | None: ... + @overload + def get( + self, key: str, default: T | None = None, type: Callable[[str], T] = ... + ) -> T | None: ... 
+ @overload + def getlist(self, key: str) -> list[str]: ... + @overload + def getlist(self, key: str, type: Callable[[str], T]) -> list[T]: ... + def get_all(self, name: str) -> list[str]: ... + def items( # type: ignore + self, lower: bool = False + ) -> Iterator[tuple[str, str]]: ... + def keys(self, lower: bool = False) -> Iterator[str]: ... # type: ignore + def values(self) -> Iterator[str]: ... # type: ignore + def extend( + self, + *args: Mapping[str, str | Iterable[str]] | Iterable[tuple[str, str]], + **kwargs: str | Iterable[str], + ) -> None: ... + @overload + def __delitem__(self, key: str | int | slice) -> None: ... + @overload + def __delitem__(self, key: str, _index_operation: Literal[False]) -> None: ... + def remove(self, key: str) -> None: ... + @overload # type: ignore + def pop(self, key: str, default: str | None = None) -> str: ... + @overload + def pop( + self, key: int | None = None, default: tuple[str, str] | None = None + ) -> tuple[str, str]: ... + def popitem(self) -> tuple[str, str]: ... + def __contains__(self, key: str) -> bool: ... # type: ignore + def has_key(self, key: str) -> bool: ... + def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore + def add(self, _key: str, _value: str, **kw: str) -> None: ... + def _validate_value(self, value: str) -> None: ... + def add_header(self, _key: str, _value: str, **_kw: str) -> None: ... + def clear(self) -> None: ... + def set(self, _key: str, _value: str, **kw: str) -> None: ... + def setlist(self, key: str, values: Iterable[str]) -> None: ... + def setdefault(self, key: str, default: str) -> str: ... + def setlistdefault(self, key: str, default: Iterable[str]) -> None: ... + @overload + def __setitem__(self, key: str, value: str) -> None: ... + @overload + def __setitem__(self, key: int, value: tuple[str, str]) -> None: ... + @overload + def __setitem__(self, key: slice, value: Iterable[tuple[str, str]]) -> None: ... 
+ @overload + def update( + self, __m: SupportsKeysAndGetItem[str, str], **kwargs: str | Iterable[str] + ) -> None: ... + @overload + def update( + self, __m: Iterable[tuple[str, str]], **kwargs: str | Iterable[str] + ) -> None: ... + @overload + def update(self, **kwargs: str | Iterable[str]) -> None: ... + def to_wsgi_list(self) -> list[tuple[str, str]]: ... + def copy(self) -> Headers: ... + def __copy__(self) -> Headers: ... + +class EnvironHeaders(ImmutableHeadersMixin, Headers): + environ: WSGIEnvironment + def __init__(self, environ: WSGIEnvironment) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__( # type: ignore + self, key: str, _get_mode: Literal[False] = False + ) -> str: ... + def __iter__(self) -> Iterator[tuple[str, str]]: ... # type: ignore + def copy(self) -> NoReturn: ... diff --git a/src/werkzeug/datastructures/mixins.py b/src/werkzeug/datastructures/mixins.py new file mode 100644 index 0000000..2c84ca8 --- /dev/null +++ b/src/werkzeug/datastructures/mixins.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from itertools import repeat + +from .._internal import _missing + + +def is_immutable(self): + raise TypeError(f"{type(self).__name__!r} objects are immutable") + + +class ImmutableListMixin: + """Makes a :class:`list` immutable. + + .. 
versionadded:: 0.5 + + :private: + """ + + _hash_cache = None + + def __hash__(self): + if self._hash_cache is not None: + return self._hash_cache + rv = self._hash_cache = hash(tuple(self)) + return rv + + def __reduce_ex__(self, protocol): + return type(self), (list(self),) + + def __delitem__(self, key): + is_immutable(self) + + def __iadd__(self, other): + is_immutable(self) + + def __imul__(self, other): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def append(self, item): + is_immutable(self) + + def remove(self, item): + is_immutable(self) + + def extend(self, iterable): + is_immutable(self) + + def insert(self, pos, value): + is_immutable(self) + + def pop(self, index=-1): + is_immutable(self) + + def reverse(self): + is_immutable(self) + + def sort(self, key=None, reverse=False): + is_immutable(self) + + +class ImmutableDictMixin: + """Makes a :class:`dict` immutable. + + .. versionadded:: 0.5 + + :private: + """ + + _hash_cache = None + + @classmethod + def fromkeys(cls, keys, value=None): + instance = super().__new__(cls) + instance.__init__(zip(keys, repeat(value))) + return instance + + def __reduce_ex__(self, protocol): + return type(self), (dict(self),) + + def _iter_hashitems(self): + return self.items() + + def __hash__(self): + if self._hash_cache is not None: + return self._hash_cache + rv = self._hash_cache = hash(frozenset(self._iter_hashitems())) + return rv + + def setdefault(self, key, default=None): + is_immutable(self) + + def update(self, *args, **kwargs): + is_immutable(self) + + def pop(self, key, default=None): + is_immutable(self) + + def popitem(self): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def __delitem__(self, key): + is_immutable(self) + + def clear(self): + is_immutable(self) + + +class ImmutableMultiDictMixin(ImmutableDictMixin): + """Makes a :class:`MultiDict` immutable. + + .. 
versionadded:: 0.5 + + :private: + """ + + def __reduce_ex__(self, protocol): + return type(self), (list(self.items(multi=True)),) + + def _iter_hashitems(self): + return self.items(multi=True) + + def add(self, key, value): + is_immutable(self) + + def popitemlist(self): + is_immutable(self) + + def poplist(self, key): + is_immutable(self) + + def setlist(self, key, new_list): + is_immutable(self) + + def setlistdefault(self, key, default_list=None): + is_immutable(self) + + +class ImmutableHeadersMixin: + """Makes a :class:`Headers` immutable. We do not mark them as + hashable though since the only usecase for this datastructure + in Werkzeug is a view on a mutable structure. + + .. versionadded:: 0.5 + + :private: + """ + + def __delitem__(self, key, **kwargs): + is_immutable(self) + + def __setitem__(self, key, value): + is_immutable(self) + + def set(self, _key, _value, **kwargs): + is_immutable(self) + + def setlist(self, key, values): + is_immutable(self) + + def add(self, _key, _value, **kwargs): + is_immutable(self) + + def add_header(self, _key, _value, **_kwargs): + is_immutable(self) + + def remove(self, key): + is_immutable(self) + + def extend(self, *args, **kwargs): + is_immutable(self) + + def update(self, *args, **kwargs): + is_immutable(self) + + def insert(self, pos, value): + is_immutable(self) + + def pop(self, key=None, default=_missing): + is_immutable(self) + + def popitem(self): + is_immutable(self) + + def setdefault(self, key, default): + is_immutable(self) + + def setlistdefault(self, key, default): + is_immutable(self) + + +def _calls_update(name): + def oncall(self, *args, **kw): + rv = getattr(super(UpdateDictMixin, self), name)(*args, **kw) + + if self.on_update is not None: + self.on_update(self) + + return rv + + oncall.__name__ = name + return oncall + + +class UpdateDictMixin(dict): + """Makes dicts call `self.on_update` on modifications. + + .. 
versionadded:: 0.5 + + :private: + """ + + on_update = None + + def setdefault(self, key, default=None): + modified = key not in self + rv = super().setdefault(key, default) + if modified and self.on_update is not None: + self.on_update(self) + return rv + + def pop(self, key, default=_missing): + modified = key in self + if default is _missing: + rv = super().pop(key) + else: + rv = super().pop(key, default) + if modified and self.on_update is not None: + self.on_update(self) + return rv + + __setitem__ = _calls_update("__setitem__") + __delitem__ = _calls_update("__delitem__") + clear = _calls_update("clear") + popitem = _calls_update("popitem") + update = _calls_update("update") diff --git a/src/werkzeug/datastructures/mixins.pyi b/src/werkzeug/datastructures/mixins.pyi new file mode 100644 index 0000000..74ed4b8 --- /dev/null +++ b/src/werkzeug/datastructures/mixins.pyi @@ -0,0 +1,97 @@ +from collections.abc import Callable +from collections.abc import Hashable +from collections.abc import Iterable +from typing import Any +from typing import NoReturn +from typing import overload +from typing import SupportsIndex +from typing import TypeVar + +from _typeshed import SupportsKeysAndGetItem + +from .headers import Headers + +K = TypeVar("K") +T = TypeVar("T") +V = TypeVar("V") + +def is_immutable(self: object) -> NoReturn: ... + +class ImmutableListMixin(list[V]): + _hash_cache: int | None + def __hash__(self) -> int: ... # type: ignore + def __delitem__(self, key: SupportsIndex | slice) -> NoReturn: ... + def __iadd__(self, other: t.Any) -> NoReturn: ... # type: ignore + def __imul__(self, other: SupportsIndex) -> NoReturn: ... + def __setitem__(self, key: int | slice, value: V) -> NoReturn: ... # type: ignore + def append(self, value: V) -> NoReturn: ... + def remove(self, value: V) -> NoReturn: ... + def extend(self, values: Iterable[V]) -> NoReturn: ... + def insert(self, pos: SupportsIndex, value: V) -> NoReturn: ... 
+ def pop(self, index: SupportsIndex = -1) -> NoReturn: ... + def reverse(self) -> NoReturn: ... + def sort( + self, key: Callable[[V], Any] | None = None, reverse: bool = False + ) -> NoReturn: ... + +class ImmutableDictMixin(dict[K, V]): + _hash_cache: int | None + @classmethod + def fromkeys( # type: ignore + cls, keys: Iterable[K], value: V | None = None + ) -> ImmutableDictMixin[K, V]: ... + def _iter_hashitems(self) -> Iterable[Hashable]: ... + def __hash__(self) -> int: ... # type: ignore + def setdefault(self, key: K, default: V | None = None) -> NoReturn: ... + def update(self, *args: Any, **kwargs: V) -> NoReturn: ... + def pop(self, key: K, default: V | None = None) -> NoReturn: ... # type: ignore + def popitem(self) -> NoReturn: ... + def __setitem__(self, key: K, value: V) -> NoReturn: ... + def __delitem__(self, key: K) -> NoReturn: ... + def clear(self) -> NoReturn: ... + +class ImmutableMultiDictMixin(ImmutableDictMixin[K, V]): + def _iter_hashitems(self) -> Iterable[Hashable]: ... + def add(self, key: K, value: V) -> NoReturn: ... + def popitemlist(self) -> NoReturn: ... + def poplist(self, key: K) -> NoReturn: ... + def setlist(self, key: K, new_list: Iterable[V]) -> NoReturn: ... + def setlistdefault( + self, key: K, default_list: Iterable[V] | None = None + ) -> NoReturn: ... + +class ImmutableHeadersMixin(Headers): + def __delitem__(self, key: Any, _index_operation: bool = True) -> NoReturn: ... + def __setitem__(self, key: Any, value: Any) -> NoReturn: ... + def set(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... + def setlist(self, key: Any, values: Any) -> NoReturn: ... + def add(self, _key: Any, _value: Any, **kw: Any) -> NoReturn: ... + def add_header(self, _key: Any, _value: Any, **_kw: Any) -> NoReturn: ... + def remove(self, key: Any) -> NoReturn: ... + def extend(self, *args: Any, **kwargs: Any) -> NoReturn: ... + def update(self, *args: Any, **kwargs: Any) -> NoReturn: ... 
+    def insert(self, pos: Any, value: Any) -> NoReturn: ...
+    def pop(self, key: Any = None, default: Any = ...) -> NoReturn: ...
+    def popitem(self) -> NoReturn: ...
+    def setdefault(self, key: Any, default: Any) -> NoReturn: ...
+    def setlistdefault(self, key: Any, default: Any) -> NoReturn: ...
+
+def _calls_update(name: str) -> Callable[[UpdateDictMixin[K, V]], Any]: ...
+
+class UpdateDictMixin(dict[K, V]):
+    on_update: Callable[[UpdateDictMixin[K, V]], None] | None
+    def setdefault(self, key: K, default: V | None = None) -> V: ...
+    @overload
+    def pop(self, key: K) -> V: ...
+    @overload
+    def pop(self, key: K, default: V | T = ...) -> V | T: ...
+    def __setitem__(self, key: K, value: V) -> None: ...
+    def __delitem__(self, key: K) -> None: ...
+    def clear(self) -> None: ...
+    def popitem(self) -> tuple[K, V]: ...
+    @overload
+    def update(self, __m: SupportsKeysAndGetItem[K, V], **kwargs: V) -> None: ...
+    @overload
+    def update(self, __m: Iterable[tuple[K, V]], **kwargs: V) -> None: ...
+    @overload
+    def update(self, **kwargs: V) -> None: ...
diff --git a/src/werkzeug/datastructures/range.py b/src/werkzeug/datastructures/range.py
new file mode 100644
index 0000000..7011ea4
--- /dev/null
+++ b/src/werkzeug/datastructures/range.py
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+
+class IfRange:
+    """Very simple object that represents the `If-Range` header in parsed
+    form. It will have either an etag or a date, or neither of the two,
+    but never both.
+
+    .. versionadded:: 0.7
+    """
+
+    def __init__(self, etag=None, date=None):
+        #: The etag parsed and unquoted. Ranges always operate on strong
+        #: etags so the weakness information is not necessary.
+        self.etag = etag
+        #: The date in parsed format or `None`.
+ self.date = date + + def to_header(self): + """Converts the object back into an HTTP header.""" + if self.date is not None: + return http.http_date(self.date) + if self.etag is not None: + return http.quote_etag(self.etag) + return "" + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" + + +class Range: + """Represents a ``Range`` header. All methods only support only + bytes as the unit. Stores a list of ranges if given, but the methods + only work if only one range is provided. + + :raise ValueError: If the ranges provided are invalid. + + .. versionchanged:: 0.15 + The ranges passed in are validated. + + .. versionadded:: 0.7 + """ + + def __init__(self, units, ranges): + #: The units of this range. Usually "bytes". + self.units = units + #: A list of ``(begin, end)`` tuples for the range header provided. + #: The ranges are non-inclusive. + self.ranges = ranges + + for start, end in ranges: + if start is None or (end is not None and (start < 0 or start >= end)): + raise ValueError(f"{(start, end)} is not a valid range.") + + def range_for_length(self, length): + """If the range is for bytes, the length is not None and there is + exactly one range and it is satisfiable it returns a ``(start, stop)`` + tuple, otherwise `None`. + """ + if self.units != "bytes" or length is None or len(self.ranges) != 1: + return None + start, end = self.ranges[0] + if end is None: + end = length + if start < 0: + start += length + if http.is_byte_range_valid(start, end, length): + return start, min(end, length) + return None + + def make_content_range(self, length): + """Creates a :class:`~werkzeug.datastructures.ContentRange` object + from the current range and given content length. 
+ """ + rng = self.range_for_length(length) + if rng is not None: + return ContentRange(self.units, rng[0], rng[1], length) + return None + + def to_header(self): + """Converts the object back into an HTTP header.""" + ranges = [] + for begin, end in self.ranges: + if end is None: + ranges.append(f"{begin}-" if begin >= 0 else str(begin)) + else: + ranges.append(f"{begin}-{end - 1}") + return f"{self.units}={','.join(ranges)}" + + def to_content_range_header(self, length): + """Converts the object into `Content-Range` HTTP header, + based on given length + """ + range = self.range_for_length(length) + if range is not None: + return f"{self.units} {range[0]}-{range[1] - 1}/{length}" + return None + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" + + +def _callback_property(name): + def fget(self): + return getattr(self, name) + + def fset(self, value): + setattr(self, name, value) + if self.on_update is not None: + self.on_update(self) + + return property(fget, fset) + + +class ContentRange: + """Represents the content range header. + + .. versionadded:: 0.7 + """ + + def __init__(self, units, start, stop, length=None, on_update=None): + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" + self.on_update = on_update + self.set(start, stop, length, units) + + #: The units to use, usually "bytes" + units = _callback_property("_units") + #: The start point of the range or `None`. + start = _callback_property("_start") + #: The stop point of the range (non-inclusive) or `None`. Can only be + #: `None` if also start is `None`. + stop = _callback_property("_stop") + #: The length of the range or `None`. 
+ length = _callback_property("_length") + + def set(self, start, stop, length=None, units="bytes"): + """Simple method to update the ranges.""" + assert http.is_byte_range_valid(start, stop, length), "Bad range provided" + self._units = units + self._start = start + self._stop = stop + self._length = length + if self.on_update is not None: + self.on_update(self) + + def unset(self): + """Sets the units to `None` which indicates that the header should + no longer be used. + """ + self.set(None, None, units=None) + + def to_header(self): + if self.units is None: + return "" + if self.length is None: + length = "*" + else: + length = self.length + if self.start is None: + return f"{self.units} */{length}" + return f"{self.units} {self.start}-{self.stop - 1}/{length}" + + def __bool__(self): + return self.units is not None + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"<{type(self).__name__} {str(self)!r}>" + + +# circular dependencies +from .. import http diff --git a/src/werkzeug/datastructures/range.pyi b/src/werkzeug/datastructures/range.pyi new file mode 100644 index 0000000..f38ad69 --- /dev/null +++ b/src/werkzeug/datastructures/range.pyi @@ -0,0 +1,57 @@ +from collections.abc import Callable +from datetime import datetime + +class IfRange: + etag: str | None + date: datetime | None + def __init__( + self, etag: str | None = None, date: datetime | None = None + ) -> None: ... + def to_header(self) -> str: ... + +class Range: + units: str + ranges: list[tuple[int, int | None]] + def __init__(self, units: str, ranges: list[tuple[int, int | None]]) -> None: ... + def range_for_length(self, length: int | None) -> tuple[int, int] | None: ... + def make_content_range(self, length: int | None) -> ContentRange | None: ... + def to_header(self) -> str: ... + def to_content_range_header(self, length: int | None) -> str | None: ... + +def _callback_property(name: str) -> property: ... 
+ +class ContentRange: + on_update: Callable[[ContentRange], None] | None + def __init__( + self, + units: str | None, + start: int | None, + stop: int | None, + length: int | None = None, + on_update: Callable[[ContentRange], None] | None = None, + ) -> None: ... + @property + def units(self) -> str | None: ... + @units.setter + def units(self, value: str | None) -> None: ... + @property + def start(self) -> int | None: ... + @start.setter + def start(self, value: int | None) -> None: ... + @property + def stop(self) -> int | None: ... + @stop.setter + def stop(self, value: int | None) -> None: ... + @property + def length(self) -> int | None: ... + @length.setter + def length(self, value: int | None) -> None: ... + def set( + self, + start: int | None, + stop: int | None, + length: int | None = None, + units: str | None = "bytes", + ) -> None: ... + def unset(self) -> None: ... + def to_header(self) -> str: ... diff --git a/src/werkzeug/datastructures/structures.py b/src/werkzeug/datastructures/structures.py new file mode 100644 index 0000000..7ea7bee --- /dev/null +++ b/src/werkzeug/datastructures/structures.py @@ -0,0 +1,1006 @@ +from __future__ import annotations + +from collections.abc import MutableSet +from copy import deepcopy + +from .. import exceptions +from .._internal import _missing +from .mixins import ImmutableDictMixin +from .mixins import ImmutableListMixin +from .mixins import ImmutableMultiDictMixin +from .mixins import UpdateDictMixin + + +def is_immutable(self): + raise TypeError(f"{type(self).__name__!r} objects are immutable") + + +def iter_multi_items(mapping): + """Iterates over the items of a mapping yielding keys and values + without dropping any from more complex structures. 
+ """ + if isinstance(mapping, MultiDict): + yield from mapping.items(multi=True) + elif isinstance(mapping, dict): + for key, value in mapping.items(): + if isinstance(value, (tuple, list)): + for v in value: + yield key, v + else: + yield key, value + else: + yield from mapping + + +class ImmutableList(ImmutableListMixin, list): + """An immutable :class:`list`. + + .. versionadded:: 0.5 + + :private: + """ + + def __repr__(self): + return f"{type(self).__name__}({list.__repr__(self)})" + + +class TypeConversionDict(dict): + """Works like a regular dict but the :meth:`get` method can perform + type conversions. :class:`MultiDict` and :class:`CombinedMultiDict` + are subclasses of this class and provide the same feature. + + .. versionadded:: 0.5 + """ + + def get(self, key, default=None, type=None): + """Return the default value if the requested data doesn't exist. + If `type` is provided and is a callable it should convert the value, + return it or raise a :exc:`ValueError` if that is not possible. In + this case the function will return the default as if the value was not + found: + + >>> d = TypeConversionDict(foo='42', bar='blub') + >>> d.get('foo', type=int) + 42 + >>> d.get('bar', -1, type=int) + -1 + + :param key: The key to be looked up. + :param default: The default value to be returned if the key can't + be looked up. If not further specified `None` is + returned. + :param type: A callable that is used to cast the value in the + :class:`MultiDict`. If a :exc:`ValueError` is raised + by this callable the default value is returned. + """ + try: + rv = self[key] + except KeyError: + return default + if type is not None: + try: + rv = type(rv) + except ValueError: + rv = default + return rv + + +class ImmutableTypeConversionDict(ImmutableDictMixin, TypeConversionDict): + """Works like a :class:`TypeConversionDict` but does not support + modifications. + + .. versionadded:: 0.5 + """ + + def copy(self): + """Return a shallow mutable copy of this object. 
Keep in mind that + the standard library's :func:`copy` function is a no-op for this class + like for any other python immutable type (eg: :class:`tuple`). + """ + return TypeConversionDict(self) + + def __copy__(self): + return self + + +class MultiDict(TypeConversionDict): + """A :class:`MultiDict` is a dictionary subclass customized to deal with + multiple values for the same key which is for example used by the parsing + functions in the wrappers. This is necessary because some HTML form + elements pass multiple values for the same key. + + :class:`MultiDict` implements all standard dictionary methods. + Internally, it saves all values for a key as a list, but the standard dict + access methods will only return the first value for a key. If you want to + gain access to the other values, too, you have to use the `list` methods as + explained below. + + Basic Usage: + + >>> d = MultiDict([('a', 'b'), ('a', 'c')]) + >>> d + MultiDict([('a', 'b'), ('a', 'c')]) + >>> d['a'] + 'b' + >>> d.getlist('a') + ['b', 'c'] + >>> 'a' in d + True + + It behaves like a normal dict thus all dict functions will only return the + first value when multiple values for one key are found. + + From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a + subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will + render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP + exceptions. + + A :class:`MultiDict` can be constructed from an iterable of + ``(key, value)`` tuples, a dict, a :class:`MultiDict` or from Werkzeug 0.2 + onwards some keyword parameters. + + :param mapping: the initial value for the :class:`MultiDict`. Either a + regular dict, an iterable of ``(key, value)`` tuples + or `None`. 
+ """ + + def __init__(self, mapping=None): + if isinstance(mapping, MultiDict): + dict.__init__(self, ((k, l[:]) for k, l in mapping.lists())) + elif isinstance(mapping, dict): + tmp = {} + for key, value in mapping.items(): + if isinstance(value, (tuple, list)): + if len(value) == 0: + continue + value = list(value) + else: + value = [value] + tmp[key] = value + dict.__init__(self, tmp) + else: + tmp = {} + for key, value in mapping or (): + tmp.setdefault(key, []).append(value) + dict.__init__(self, tmp) + + def __getstate__(self): + return dict(self.lists()) + + def __setstate__(self, value): + dict.clear(self) + dict.update(self, value) + + def __iter__(self): + # Work around https://bugs.python.org/issue43246. + # (`return super().__iter__()` also works here, which makes this look + # even more like it should be a no-op, yet it isn't.) + return dict.__iter__(self) + + def __getitem__(self, key): + """Return the first data value for this key; + raises KeyError if not found. + + :param key: The key to be looked up. + :raise KeyError: if the key does not exist. + """ + + if key in self: + lst = dict.__getitem__(self, key) + if len(lst) > 0: + return lst[0] + raise exceptions.BadRequestKeyError(key) + + def __setitem__(self, key, value): + """Like :meth:`add` but removes an existing key first. + + :param key: the key for the value. + :param value: the value to set. + """ + dict.__setitem__(self, key, [value]) + + def add(self, key, value): + """Adds a new value for the key. + + .. versionadded:: 0.6 + + :param key: the key for the value. + :param value: the value to add. + """ + dict.setdefault(self, key, []).append(value) + + def getlist(self, key, type=None): + """Return the list of items for a given key. If that key is not in the + `MultiDict`, the return value will be an empty list. Just like `get`, + `getlist` accepts a `type` parameter. All items will be converted + with the callable defined there. + + :param key: The key to be looked up. 
+ :param type: A callable that is used to cast the value in the + :class:`MultiDict`. If a :exc:`ValueError` is raised + by this callable the value will be removed from the list. + :return: a :class:`list` of all the values for the key. + """ + try: + rv = dict.__getitem__(self, key) + except KeyError: + return [] + if type is None: + return list(rv) + result = [] + for item in rv: + try: + result.append(type(item)) + except ValueError: + pass + return result + + def setlist(self, key, new_list): + """Remove the old values for a key and add new ones. Note that the list + you pass the values in will be shallow-copied before it is inserted in + the dictionary. + + >>> d = MultiDict() + >>> d.setlist('foo', ['1', '2']) + >>> d['foo'] + '1' + >>> d.getlist('foo') + ['1', '2'] + + :param key: The key for which the values are set. + :param new_list: An iterable with the new values for the key. Old values + are removed first. + """ + dict.__setitem__(self, key, list(new_list)) + + def setdefault(self, key, default=None): + """Returns the value for the key if it is in the dict, otherwise it + returns `default` and sets that value for `key`. + + :param key: The key to be looked up. + :param default: The default value to be returned if the key is not + in the dict. If not further specified it's `None`. + """ + if key not in self: + self[key] = default + else: + default = self[key] + return default + + def setlistdefault(self, key, default_list=None): + """Like `setdefault` but sets multiple values. The list returned + is not a copy, but the list that is actually used internally. This + means that you can put new values into the dict by appending items + to the list: + + >>> d = MultiDict({"foo": 1}) + >>> d.setlistdefault("foo").extend([2, 3]) + >>> d.getlist("foo") + [1, 2, 3] + + :param key: The key to be looked up. + :param default_list: An iterable of default values. It is either copied + (in case it was a list) or converted into a list + before returned. 
+ :return: a :class:`list` + """ + if key not in self: + default_list = list(default_list or ()) + dict.__setitem__(self, key, default_list) + else: + default_list = dict.__getitem__(self, key) + return default_list + + def items(self, multi=False): + """Return an iterator of ``(key, value)`` pairs. + + :param multi: If set to `True` the iterator returned will have a pair + for each value of each key. Otherwise it will only + contain pairs for the first value of each key. + """ + for key, values in dict.items(self): + if multi: + for value in values: + yield key, value + else: + yield key, values[0] + + def lists(self): + """Return a iterator of ``(key, values)`` pairs, where values is the list + of all values associated with the key.""" + for key, values in dict.items(self): + yield key, list(values) + + def values(self): + """Returns an iterator of the first value on every key's value list.""" + for values in dict.values(self): + yield values[0] + + def listvalues(self): + """Return an iterator of all values associated with a key. Zipping + :meth:`keys` and this is the same as calling :meth:`lists`: + + >>> d = MultiDict({"foo": [1, 2, 3]}) + >>> zip(d.keys(), d.listvalues()) == d.lists() + True + """ + return dict.values(self) + + def copy(self): + """Return a shallow copy of this object.""" + return self.__class__(self) + + def deepcopy(self, memo=None): + """Return a deep copy of this object.""" + return self.__class__(deepcopy(self.to_dict(flat=False), memo)) + + def to_dict(self, flat=True): + """Return the contents as regular dict. If `flat` is `True` the + returned dict will only have the first item present, if `flat` is + `False` all values will be returned as lists. + + :param flat: If set to `False` the dict returned will have lists + with all the values in it. Otherwise it will only + contain the first value for each key. 
+ :return: a :class:`dict` + """ + if flat: + return dict(self.items()) + return dict(self.lists()) + + def update(self, mapping): + """update() extends rather than replaces existing key lists: + + >>> a = MultiDict({'x': 1}) + >>> b = MultiDict({'x': 2, 'y': 3}) + >>> a.update(b) + >>> a + MultiDict([('y', 3), ('x', 1), ('x', 2)]) + + If the value list for a key in ``other_dict`` is empty, no new values + will be added to the dict and the key will not be created: + + >>> x = {'empty_list': []} + >>> y = MultiDict() + >>> y.update(x) + >>> y + MultiDict([]) + """ + for key, value in iter_multi_items(mapping): + MultiDict.add(self, key, value) + + def pop(self, key, default=_missing): + """Pop the first item for a list on the dict. Afterwards the + key is removed from the dict, so additional values are discarded: + + >>> d = MultiDict({"foo": [1, 2, 3]}) + >>> d.pop("foo") + 1 + >>> "foo" in d + False + + :param key: the key to pop. + :param default: if provided the value to return if the key was + not in the dictionary. + """ + try: + lst = dict.pop(self, key) + + if len(lst) == 0: + raise exceptions.BadRequestKeyError(key) + + return lst[0] + except KeyError: + if default is not _missing: + return default + + raise exceptions.BadRequestKeyError(key) from None + + def popitem(self): + """Pop an item from the dict.""" + try: + item = dict.popitem(self) + + if len(item[1]) == 0: + raise exceptions.BadRequestKeyError(item[0]) + + return (item[0], item[1][0]) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + def poplist(self, key): + """Pop the list for a key from the dict. If the key is not in the dict + an empty list is returned. + + .. versionchanged:: 0.5 + If the key does no longer exist a list is returned instead of + raising an error. 
+ """ + return dict.pop(self, key, []) + + def popitemlist(self): + """Pop a ``(key, list)`` tuple from the dict.""" + try: + return dict.popitem(self) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + def __copy__(self): + return self.copy() + + def __deepcopy__(self, memo): + return self.deepcopy(memo=memo) + + def __repr__(self): + return f"{type(self).__name__}({list(self.items(multi=True))!r})" + + +class _omd_bucket: + """Wraps values in the :class:`OrderedMultiDict`. This makes it + possible to keep an order over multiple different keys. It requires + a lot of extra memory and slows down access a lot, but makes it + possible to access elements in O(1) and iterate in O(n). + """ + + __slots__ = ("prev", "key", "value", "next") + + def __init__(self, omd, key, value): + self.prev = omd._last_bucket + self.key = key + self.value = value + self.next = None + + if omd._first_bucket is None: + omd._first_bucket = self + if omd._last_bucket is not None: + omd._last_bucket.next = self + omd._last_bucket = self + + def unlink(self, omd): + if self.prev: + self.prev.next = self.next + if self.next: + self.next.prev = self.prev + if omd._first_bucket is self: + omd._first_bucket = self.next + if omd._last_bucket is self: + omd._last_bucket = self.prev + + +class OrderedMultiDict(MultiDict): + """Works like a regular :class:`MultiDict` but preserves the + order of the fields. To convert the ordered multi dict into a + list you can use the :meth:`items` method and pass it ``multi=True``. + + In general an :class:`OrderedMultiDict` is an order of magnitude + slower than a :class:`MultiDict`. + + .. admonition:: note + + Due to a limitation in Python you cannot convert an ordered + multi dict into a regular dict by using ``dict(multidict)``. + Instead you have to use the :meth:`to_dict` method, otherwise + the internal bucket objects are exposed. 
+ """ + + def __init__(self, mapping=None): + dict.__init__(self) + self._first_bucket = self._last_bucket = None + if mapping is not None: + OrderedMultiDict.update(self, mapping) + + def __eq__(self, other): + if not isinstance(other, MultiDict): + return NotImplemented + if isinstance(other, OrderedMultiDict): + iter1 = iter(self.items(multi=True)) + iter2 = iter(other.items(multi=True)) + try: + for k1, v1 in iter1: + k2, v2 = next(iter2) + if k1 != k2 or v1 != v2: + return False + except StopIteration: + return False + try: + next(iter2) + except StopIteration: + return True + return False + if len(self) != len(other): + return False + for key, values in self.lists(): + if other.getlist(key) != values: + return False + return True + + __hash__ = None + + def __reduce_ex__(self, protocol): + return type(self), (list(self.items(multi=True)),) + + def __getstate__(self): + return list(self.items(multi=True)) + + def __setstate__(self, values): + dict.clear(self) + for key, value in values: + self.add(key, value) + + def __getitem__(self, key): + if key in self: + return dict.__getitem__(self, key)[0].value + raise exceptions.BadRequestKeyError(key) + + def __setitem__(self, key, value): + self.poplist(key) + self.add(key, value) + + def __delitem__(self, key): + self.pop(key) + + def keys(self): + return (key for key, value in self.items()) + + def __iter__(self): + return iter(self.keys()) + + def values(self): + return (value for key, value in self.items()) + + def items(self, multi=False): + ptr = self._first_bucket + if multi: + while ptr is not None: + yield ptr.key, ptr.value + ptr = ptr.next + else: + returned_keys = set() + while ptr is not None: + if ptr.key not in returned_keys: + returned_keys.add(ptr.key) + yield ptr.key, ptr.value + ptr = ptr.next + + def lists(self): + returned_keys = set() + ptr = self._first_bucket + while ptr is not None: + if ptr.key not in returned_keys: + yield ptr.key, self.getlist(ptr.key) + returned_keys.add(ptr.key) + ptr 
= ptr.next + + def listvalues(self): + for _key, values in self.lists(): + yield values + + def add(self, key, value): + dict.setdefault(self, key, []).append(_omd_bucket(self, key, value)) + + def getlist(self, key, type=None): + try: + rv = dict.__getitem__(self, key) + except KeyError: + return [] + if type is None: + return [x.value for x in rv] + result = [] + for item in rv: + try: + result.append(type(item.value)) + except ValueError: + pass + return result + + def setlist(self, key, new_list): + self.poplist(key) + for value in new_list: + self.add(key, value) + + def setlistdefault(self, key, default_list=None): + raise TypeError("setlistdefault is unsupported for ordered multi dicts") + + def update(self, mapping): + for key, value in iter_multi_items(mapping): + OrderedMultiDict.add(self, key, value) + + def poplist(self, key): + buckets = dict.pop(self, key, ()) + for bucket in buckets: + bucket.unlink(self) + return [x.value for x in buckets] + + def pop(self, key, default=_missing): + try: + buckets = dict.pop(self, key) + except KeyError: + if default is not _missing: + return default + + raise exceptions.BadRequestKeyError(key) from None + + for bucket in buckets: + bucket.unlink(self) + + return buckets[0].value + + def popitem(self): + try: + key, buckets = dict.popitem(self) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + for bucket in buckets: + bucket.unlink(self) + + return key, buckets[0].value + + def popitemlist(self): + try: + key, buckets = dict.popitem(self) + except KeyError as e: + raise exceptions.BadRequestKeyError(e.args[0]) from None + + for bucket in buckets: + bucket.unlink(self) + + return key, [x.value for x in buckets] + + +class CombinedMultiDict(ImmutableMultiDictMixin, MultiDict): + """A read only :class:`MultiDict` that you can pass multiple :class:`MultiDict` + instances as sequence and it will combine the return values of all wrapped + dicts: + + >>> from werkzeug.datastructures 
import CombinedMultiDict, MultiDict
+    >>> post = MultiDict([('foo', 'bar')])
+    >>> get = MultiDict([('blub', 'blah')])
+    >>> combined = CombinedMultiDict([get, post])
+    >>> combined['foo']
+    'bar'
+    >>> combined['blub']
+    'blah'
+
+    This works for all read operations and will raise a `TypeError` for
+    methods that usually change data which isn't possible.
+
+    From Werkzeug 0.3 onwards, the `KeyError` raised by this class is also a
+    subclass of the :exc:`~exceptions.BadRequest` HTTP exception and will
+    render a page for a ``400 BAD REQUEST`` if caught in a catch-all for HTTP
+    exceptions.
+    """
+
+    def __reduce_ex__(self, protocol):
+        return type(self), (self.dicts,)
+
+    def __init__(self, dicts=None):
+        # ``list(dicts or ())`` rather than ``list(dicts) or []``:
+        # ``dicts`` defaults to None, and list(None) raises TypeError.
+        self.dicts = list(dicts or ())
+
+    @classmethod
+    def fromkeys(cls, keys, value=None):
+        raise TypeError(f"cannot create {cls.__name__!r} instances by fromkeys")
+
+    def __getitem__(self, key):
+        for d in self.dicts:
+            if key in d:
+                return d[key]
+        raise exceptions.BadRequestKeyError(key)
+
+    def get(self, key, default=None, type=None):
+        for d in self.dicts:
+            if key in d:
+                if type is not None:
+                    try:
+                        return type(d[key])
+                    except ValueError:
+                        continue
+                return d[key]
+        return default
+
+    def getlist(self, key, type=None):
+        rv = []
+        for d in self.dicts:
+            rv.extend(d.getlist(key, type))
+        return rv
+
+    def _keys_impl(self):
+        """This function exists so __len__ can be implemented more efficiently,
+        saving one list creation from an iterator.
+        """
+        rv = set()
+        rv.update(*self.dicts)
+        return rv
+
+    def keys(self):
+        return self._keys_impl()
+
+    def __iter__(self):
+        return iter(self.keys())
+
+    def items(self, multi=False):
+        found = set()
+        for d in self.dicts:
+            for key, value in d.items(multi):
+                if multi:
+                    yield key, value
+                elif key not in found:
+                    found.add(key)
+                    yield key, value
+
+    def values(self):
+        for _key, value in self.items():
+            yield value
+
+    def lists(self):
+        rv = {}
+        for d in self.dicts:
+            for key, values in d.lists():
+                rv.setdefault(key, []).extend(values)
+        return list(rv.items())
+
+    def listvalues(self):
+        return (x[1] for x in self.lists())
+
+    def copy(self):
+        """Return a shallow mutable copy of this object.
+
+        This returns a :class:`MultiDict` representing the data at the
+        time of copying. The copy will no longer reflect changes to the
+        wrapped dicts.
+
+        .. versionchanged:: 0.15
+            Return a mutable :class:`MultiDict`.
+        """
+        return MultiDict(self)
+
+    def to_dict(self, flat=True):
+        """Return the contents as regular dict. If `flat` is `True` the
+        returned dict will only have the first item present, if `flat` is
+        `False` all values will be returned as lists.
+
+        :param flat: If set to `False` the dict returned will have lists
+                     with all the values in it. Otherwise it will only
+                     contain the first item for each key.
+        :return: a :class:`dict`
+        """
+        if flat:
+            return dict(self.items())
+
+        return dict(self.lists())
+
+    def __len__(self):
+        return len(self._keys_impl())
+
+    def __contains__(self, key):
+        for d in self.dicts:
+            if key in d:
+                return True
+        return False
+
+    def __repr__(self):
+        return f"{type(self).__name__}({self.dicts!r})"
+
+
+class ImmutableDict(ImmutableDictMixin, dict):
+    """An immutable :class:`dict`.
+
+    .. versionadded:: 0.5
+    """
+
+    def __repr__(self):
+        return f"{type(self).__name__}({dict.__repr__(self)})"
+
+    def copy(self):
+        """Return a shallow mutable copy of this object.
Keep in mind that + the standard library's :func:`copy` function is a no-op for this class + like for any other python immutable type (eg: :class:`tuple`). + """ + return dict(self) + + def __copy__(self): + return self + + +class ImmutableMultiDict(ImmutableMultiDictMixin, MultiDict): + """An immutable :class:`MultiDict`. + + .. versionadded:: 0.5 + """ + + def copy(self): + """Return a shallow mutable copy of this object. Keep in mind that + the standard library's :func:`copy` function is a no-op for this class + like for any other python immutable type (eg: :class:`tuple`). + """ + return MultiDict(self) + + def __copy__(self): + return self + + +class ImmutableOrderedMultiDict(ImmutableMultiDictMixin, OrderedMultiDict): + """An immutable :class:`OrderedMultiDict`. + + .. versionadded:: 0.6 + """ + + def _iter_hashitems(self): + return enumerate(self.items(multi=True)) + + def copy(self): + """Return a shallow mutable copy of this object. Keep in mind that + the standard library's :func:`copy` function is a no-op for this class + like for any other python immutable type (eg: :class:`tuple`). + """ + return OrderedMultiDict(self) + + def __copy__(self): + return self + + +class CallbackDict(UpdateDictMixin, dict): + """A dict that calls a function passed every time something is changed. + The function is passed the dict instance. + """ + + def __init__(self, initial=None, on_update=None): + dict.__init__(self, initial or ()) + self.on_update = on_update + + def __repr__(self): + return f"<{type(self).__name__} {dict.__repr__(self)}>" + + +class HeaderSet(MutableSet): + """Similar to the :class:`ETags` class this implements a set-like structure. + Unlike :class:`ETags` this is case insensitive and used for vary, allow, and + content-language headers. 
+
+    If not constructed using the :func:`parse_set_header` function the
+    instantiation works like this:
+
+    >>> hs = HeaderSet(['foo', 'bar', 'baz'])
+    >>> hs
+    HeaderSet(['foo', 'bar', 'baz'])
+    """
+
+    def __init__(self, headers=None, on_update=None):
+        self._headers = list(headers or ())
+        self._set = {x.lower() for x in self._headers}
+        self.on_update = on_update
+
+    def add(self, header):
+        """Add a new header to the set."""
+        self.update((header,))
+
+    def remove(self, header):
+        """Remove a header from the set. This raises a :exc:`KeyError` if the
+        header is not in the set.
+
+        .. versionchanged:: 0.5
+            In older versions an :exc:`IndexError` was raised instead of a
+            :exc:`KeyError` if the object was missing.
+
+        :param header: the header to be removed.
+        """
+        key = header.lower()
+        if key not in self._set:
+            raise KeyError(header)
+        self._set.remove(key)
+        for idx, stored in enumerate(self._headers):
+            if stored.lower() == key:
+                del self._headers[idx]
+                break
+        if self.on_update is not None:
+            self.on_update(self)
+
+    def update(self, iterable):
+        """Add all the headers from the iterable to the set.
+
+        :param iterable: updates the set with the items from the iterable.
+        """
+        inserted_any = False
+        for header in iterable:
+            key = header.lower()
+            if key not in self._set:
+                self._headers.append(header)
+                self._set.add(key)
+                inserted_any = True
+        if inserted_any and self.on_update is not None:
+            self.on_update(self)
+
+    def discard(self, header):
+        """Like :meth:`remove` but ignores errors.
+
+        :param header: the header to be discarded.
+        """
+        try:
+            self.remove(header)
+        except KeyError:
+            pass
+
+    def find(self, header):
+        """Return the index of the header in the set or return -1 if not found.
+
+        :param header: the header to be looked up.
+ """ + header = header.lower() + for idx, item in enumerate(self._headers): + if item.lower() == header: + return idx + return -1 + + def index(self, header): + """Return the index of the header in the set or raise an + :exc:`IndexError`. + + :param header: the header to be looked up. + """ + rv = self.find(header) + if rv < 0: + raise IndexError(header) + return rv + + def clear(self): + """Clear the set.""" + self._set.clear() + del self._headers[:] + if self.on_update is not None: + self.on_update(self) + + def as_set(self, preserve_casing=False): + """Return the set as real python set type. When calling this, all + the items are converted to lowercase and the ordering is lost. + + :param preserve_casing: if set to `True` the items in the set returned + will have the original case like in the + :class:`HeaderSet`, otherwise they will + be lowercase. + """ + if preserve_casing: + return set(self._headers) + return set(self._set) + + def to_header(self): + """Convert the header set into an HTTP header string.""" + return ", ".join(map(http.quote_header_value, self._headers)) + + def __getitem__(self, idx): + return self._headers[idx] + + def __delitem__(self, idx): + rv = self._headers.pop(idx) + self._set.remove(rv.lower()) + if self.on_update is not None: + self.on_update(self) + + def __setitem__(self, idx, value): + old = self._headers[idx] + self._set.remove(old.lower()) + self._headers[idx] = value + self._set.add(value.lower()) + if self.on_update is not None: + self.on_update(self) + + def __contains__(self, header): + return header.lower() in self._set + + def __len__(self): + return len(self._set) + + def __iter__(self): + return iter(self._headers) + + def __bool__(self): + return bool(self._set) + + def __str__(self): + return self.to_header() + + def __repr__(self): + return f"{type(self).__name__}({self._headers!r})" + + +# circular dependencies +from .. 
import http diff --git a/src/werkzeug/datastructures/structures.pyi b/src/werkzeug/datastructures/structures.pyi new file mode 100644 index 0000000..2e7af35 --- /dev/null +++ b/src/werkzeug/datastructures/structures.pyi @@ -0,0 +1,208 @@ +from collections.abc import Callable +from collections.abc import Iterable +from collections.abc import Iterator +from collections.abc import Mapping +from typing import Any +from typing import Generic +from typing import Literal +from typing import NoReturn +from typing import overload +from typing import TypeVar + +from .mixins import ( + ImmutableDictMixin, + ImmutableListMixin, + ImmutableMultiDictMixin, + UpdateDictMixin, +) + +D = TypeVar("D") +K = TypeVar("K") +T = TypeVar("T") +V = TypeVar("V") +_CD = TypeVar("_CD", bound="CallbackDict") + +def is_immutable(self: object) -> NoReturn: ... +def iter_multi_items( + mapping: Mapping[K, V | Iterable[V]] | Iterable[tuple[K, V]] +) -> Iterator[tuple[K, V]]: ... + +class ImmutableList(ImmutableListMixin[V]): ... + +class TypeConversionDict(dict[K, V]): + @overload + def get(self, key: K, default: None = ..., type: None = ...) -> V | None: ... + @overload + def get(self, key: K, default: D, type: None = ...) -> D | V: ... + @overload + def get(self, key: K, default: D, type: Callable[[V], T]) -> D | T: ... + @overload + def get(self, key: K, type: Callable[[V], T]) -> T | None: ... + +class ImmutableTypeConversionDict(ImmutableDictMixin[K, V], TypeConversionDict[K, V]): + def copy(self) -> TypeConversionDict[K, V]: ... + def __copy__(self) -> ImmutableTypeConversionDict: ... + +class MultiDict(TypeConversionDict[K, V]): + def __init__( + self, + mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] | None = None, + ) -> None: ... + def __getitem__(self, item: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> list[V]: ... 
+ @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... + def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setdefault(self, key: K, default: V | None = None) -> V: ... + def setlistdefault( + self, key: K, default_list: Iterable[V] | None = None + ) -> list[V]: ... + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... + def values(self) -> Iterator[V]: ... # type: ignore + def listvalues(self) -> Iterator[list[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + def deepcopy(self, memo: Any = None) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... + def update( # type: ignore + self, mapping: Mapping[K, Iterable[V] | V] | Iterable[tuple[K, V]] + ) -> None: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: V | T = ...) -> V | T: ... + def popitem(self) -> tuple[K, V]: ... + def poplist(self, key: K) -> list[V]: ... + def popitemlist(self) -> tuple[K, list[V]]: ... + def __copy__(self) -> MultiDict[K, V]: ... + def __deepcopy__(self, memo: Any) -> MultiDict[K, V]: ... + +class _omd_bucket(Generic[K, V]): + prev: _omd_bucket | None + next: _omd_bucket | None + key: K + value: V + def __init__(self, omd: OrderedMultiDict, key: K, value: V) -> None: ... + def unlink(self, omd: OrderedMultiDict) -> None: ... + +class OrderedMultiDict(MultiDict[K, V]): + _first_bucket: _omd_bucket | None + _last_bucket: _omd_bucket | None + def __init__(self, mapping: Mapping[K, V] | None = None) -> None: ... + def __eq__(self, other: object) -> bool: ... + def __getitem__(self, key: K) -> V: ... + def __setitem__(self, key: K, value: V) -> None: ... + def __delitem__(self, key: K) -> None: ... + def keys(self) -> Iterator[K]: ... # type: ignore + def __iter__(self) -> Iterator[K]: ... 
+ def values(self) -> Iterator[V]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... + def listvalues(self) -> Iterator[list[V]]: ... + def add(self, key: K, value: V) -> None: ... + @overload + def getlist(self, key: K) -> list[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... + def setlist(self, key: K, new_list: Iterable[V]) -> None: ... + def setlistdefault( + self, key: K, default_list: Iterable[V] | None = None + ) -> list[V]: ... + def update( # type: ignore + self, mapping: Mapping[K, V] | Iterable[tuple[K, V]] + ) -> None: ... + def poplist(self, key: K) -> list[V]: ... + @overload + def pop(self, key: K) -> V: ... + @overload + def pop(self, key: K, default: V | T = ...) -> V | T: ... + def popitem(self) -> tuple[K, V]: ... + def popitemlist(self) -> tuple[K, list[V]]: ... + +class CombinedMultiDict(ImmutableMultiDictMixin[K, V], MultiDict[K, V]): # type: ignore + dicts: list[MultiDict[K, V]] + def __init__(self, dicts: Iterable[MultiDict[K, V]] | None) -> None: ... + @classmethod + def fromkeys(cls, keys: Any, value: Any = None) -> NoReturn: ... + def __getitem__(self, key: K) -> V: ... + @overload # type: ignore + def get(self, key: K) -> V | None: ... + @overload + def get(self, key: K, default: V | T = ...) -> V | T: ... + @overload + def get( + self, key: K, default: T | None = None, type: Callable[[V], T] = ... + ) -> T | None: ... + @overload + def getlist(self, key: K) -> list[V]: ... + @overload + def getlist(self, key: K, type: Callable[[V], T] = ...) -> list[T]: ... + def _keys_impl(self) -> set[K]: ... + def keys(self) -> set[K]: ... # type: ignore + def __iter__(self) -> set[K]: ... # type: ignore + def items(self, multi: bool = False) -> Iterator[tuple[K, V]]: ... # type: ignore + def values(self) -> Iterator[V]: ... # type: ignore + def lists(self) -> Iterator[tuple[K, list[V]]]: ... 
+ def listvalues(self) -> Iterator[list[V]]: ... + def copy(self) -> MultiDict[K, V]: ... + @overload + def to_dict(self) -> dict[K, V]: ... + @overload + def to_dict(self, flat: Literal[False]) -> dict[K, list[V]]: ... + def __contains__(self, key: K) -> bool: ... # type: ignore + def has_key(self, key: K) -> bool: ... + +class ImmutableDict(ImmutableDictMixin[K, V], dict[K, V]): + def copy(self) -> dict[K, V]: ... + def __copy__(self) -> ImmutableDict[K, V]: ... + +class ImmutableMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], MultiDict[K, V] +): + def copy(self) -> MultiDict[K, V]: ... + def __copy__(self) -> ImmutableMultiDict[K, V]: ... + +class ImmutableOrderedMultiDict( # type: ignore + ImmutableMultiDictMixin[K, V], OrderedMultiDict[K, V] +): + def _iter_hashitems(self) -> Iterator[tuple[int, tuple[K, V]]]: ... + def copy(self) -> OrderedMultiDict[K, V]: ... + def __copy__(self) -> ImmutableOrderedMultiDict[K, V]: ... + +class CallbackDict(UpdateDictMixin[K, V], dict[K, V]): + def __init__( + self, + initial: Mapping[K, V] | Iterable[tuple[K, V]] | None = None, + on_update: Callable[[_CD], None] | None = None, + ) -> None: ... + +class HeaderSet(set[str]): + _headers: list[str] + _set: set[str] + on_update: Callable[[HeaderSet], None] | None + def __init__( + self, + headers: Iterable[str] | None = None, + on_update: Callable[[HeaderSet], None] | None = None, + ) -> None: ... + def add(self, header: str) -> None: ... + def remove(self, header: str) -> None: ... + def update(self, iterable: Iterable[str]) -> None: ... # type: ignore + def discard(self, header: str) -> None: ... + def find(self, header: str) -> int: ... + def index(self, header: str) -> int: ... + def clear(self) -> None: ... + def as_set(self, preserve_casing: bool = False) -> set[str]: ... + def to_header(self) -> str: ... + def __getitem__(self, idx: int) -> str: ... + def __delitem__(self, idx: int) -> None: ... + def __setitem__(self, idx: int, value: str) -> None: ... 
+ def __contains__(self, header: str) -> bool: ... # type: ignore + def __len__(self) -> int: ... + def __iter__(self) -> Iterator[str]: ... diff --git a/src/werkzeug/debug/__init__.py b/src/werkzeug/debug/__init__.py index e0dcc65..3b04b53 100644 --- a/src/werkzeug/debug/__init__.py +++ b/src/werkzeug/debug/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import getpass import hashlib import json @@ -9,7 +11,6 @@ import time import typing as t import uuid from contextlib import ExitStack -from contextlib import nullcontext from io import BytesIO from itertools import chain from os.path import basename @@ -41,16 +42,16 @@ def hash_pin(pin: str) -> str: return hashlib.sha1(f"{pin} added salt".encode("utf-8", "replace")).hexdigest()[:12] -_machine_id: t.Optional[t.Union[str, bytes]] = None +_machine_id: str | bytes | None = None -def get_machine_id() -> t.Optional[t.Union[str, bytes]]: +def get_machine_id() -> str | bytes | None: global _machine_id if _machine_id is not None: return _machine_id - def _generate() -> t.Optional[t.Union[str, bytes]]: + def _generate() -> str | bytes | None: linux = b"" # machine-id is stable across boots, boot_id is not. @@ -104,7 +105,7 @@ def get_machine_id() -> t.Optional[t.Union[str, bytes]]: 0, winreg.KEY_READ | winreg.KEY_WOW64_64KEY, ) as rk: - guid: t.Union[str, bytes] + guid: str | bytes guid_type: int guid, guid_type = winreg.QueryValueEx(rk, "MachineGuid") @@ -126,7 +127,7 @@ class _ConsoleFrame: standalone console. """ - def __init__(self, namespace: t.Dict[str, t.Any]): + def __init__(self, namespace: dict[str, t.Any]): self.console = Console(namespace) self.id = 0 @@ -135,8 +136,8 @@ class _ConsoleFrame: def get_pin_and_cookie_name( - app: "WSGIApplication", -) -> t.Union[t.Tuple[str, str], t.Tuple[None, None]]: + app: WSGIApplication, +) -> tuple[str, str] | tuple[None, None]: """Given an application object this returns a semi-stable 9 digit pin code and a random key. 
The hope is that this is stable between restarts to not make debugging particularly frustrating. If the pin @@ -161,7 +162,7 @@ def get_pin_and_cookie_name( num = pin modname = getattr(app, "__module__", t.cast(object, app).__class__.__module__) - username: t.Optional[str] + username: str | None try: # getuser imports the pwd module, which does not exist in Google @@ -229,8 +230,8 @@ class DebuggedApplication: The ``evalex`` argument allows evaluating expressions in any frame of a traceback. This works by preserving each frame with its local - state. Some state, such as :doc:`local`, cannot be restored with the - frame by default. When ``evalex`` is enabled, + state. Some state, such as context globals, cannot be restored with + the frame by default. When ``evalex`` is enabled, ``environ["werkzeug.debug.preserve_context"]`` will be a callable that takes a context manager, and can be called multiple times. Each context manager will be entered before evaluating code in the @@ -262,11 +263,11 @@ class DebuggedApplication: def __init__( self, - app: "WSGIApplication", + app: WSGIApplication, evalex: bool = False, request_key: str = "werkzeug.request", console_path: str = "/console", - console_init_func: t.Optional[t.Callable[[], t.Dict[str, t.Any]]] = None, + console_init_func: t.Callable[[], dict[str, t.Any]] | None = None, show_hidden_frames: bool = False, pin_security: bool = True, pin_logging: bool = True, @@ -275,8 +276,8 @@ class DebuggedApplication: console_init_func = None self.app = app self.evalex = evalex - self.frames: t.Dict[int, t.Union[DebugFrameSummary, _ConsoleFrame]] = {} - self.frame_contexts: t.Dict[int, t.List[t.ContextManager[None]]] = {} + self.frames: dict[int, DebugFrameSummary | _ConsoleFrame] = {} + self.frame_contexts: dict[int, list[t.ContextManager[None]]] = {} self.request_key = request_key self.console_path = console_path self.console_init_func = console_init_func @@ -297,7 +298,7 @@ class DebuggedApplication: self.pin = None @property - 
def pin(self) -> t.Optional[str]: + def pin(self) -> str | None: if not hasattr(self, "_pin"): pin_cookie = get_pin_and_cookie_name(self.app) self._pin, self._pin_cookie = pin_cookie # type: ignore @@ -316,10 +317,10 @@ class DebuggedApplication: return self._pin_cookie def debug_application( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterator[bytes]: """Run the application and conserve the traceback frames.""" - contexts: t.List[t.ContextManager[t.Any]] = [] + contexts: list[t.ContextManager[t.Any]] = [] if self.evalex: environ["werkzeug.debug.preserve_context"] = contexts.append @@ -329,7 +330,7 @@ class DebuggedApplication: app_iter = self.app(environ, start_response) yield from app_iter if hasattr(app_iter, "close"): - app_iter.close() # type: ignore + app_iter.close() except Exception as e: if hasattr(app_iter, "close"): app_iter.close() # type: ignore @@ -367,7 +368,7 @@ class DebuggedApplication: self, request: Request, command: str, - frame: t.Union[DebugFrameSummary, _ConsoleFrame], + frame: DebugFrameSummary | _ConsoleFrame, ) -> Response: """Execute a command in a console.""" contexts = self.frame_contexts.get(id(frame), []) @@ -410,7 +411,7 @@ class DebuggedApplication: BytesIO(data), request.environ, download_name=filename, etag=etag ) - def check_pin_trust(self, environ: "WSGIEnvironment") -> t.Optional[bool]: + def check_pin_trust(self, environ: WSGIEnvironment) -> bool | None: """Checks if the request passed the pin test. This returns `True` if the request is trusted on a pin/cookie basis and returns `False` if not. 
Additionally if the cookie's stored pin hash is wrong it will return @@ -497,7 +498,7 @@ class DebuggedApplication: return Response("") def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Dispatch the requests.""" # important: don't ever access a function here that reads the incoming diff --git a/src/werkzeug/debug/console.py b/src/werkzeug/debug/console.py index 69974d1..03ddc07 100644 --- a/src/werkzeug/debug/console.py +++ b/src/werkzeug/debug/console.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import code import sys import typing as t @@ -10,10 +12,7 @@ from .repr import debug_repr from .repr import dump from .repr import helper -if t.TYPE_CHECKING: - import codeop # noqa: F401 - -_stream: ContextVar["HTMLStringO"] = ContextVar("werkzeug.debug.console.stream") +_stream: ContextVar[HTMLStringO] = ContextVar("werkzeug.debug.console.stream") _ipy: ContextVar = ContextVar("werkzeug.debug.console.ipy") @@ -21,7 +20,7 @@ class HTMLStringO: """A StringO version that HTML escapes on write.""" def __init__(self) -> None: - self._buffer: t.List[str] = [] + self._buffer: list[str] = [] def isatty(self) -> bool: return False @@ -48,8 +47,6 @@ class HTMLStringO: return val def _write(self, x: str) -> None: - if isinstance(x, bytes): - x = x.decode("utf-8", "replace") self._buffer.append(x) def write(self, x: str) -> None: @@ -94,7 +91,7 @@ class ThreadedStream: def __setattr__(self, name: str, value: t.Any) -> None: raise AttributeError(f"read only attribute {name}") - def __dir__(self) -> t.List[str]: + def __dir__(self) -> list[str]: return dir(sys.__stdout__) def __getattribute__(self, name: str) -> t.Any: @@ -116,7 +113,7 @@ sys.displayhook = ThreadedStream.displayhook class _ConsoleLoader: def __init__(self) -> None: - self._storage: t.Dict[int, str] = {} + self._storage: dict[int, str] = {} def register(self, code: CodeType, source: 
str) -> None: self._storage[id(code)] = source @@ -125,7 +122,7 @@ class _ConsoleLoader: if isinstance(var, CodeType): self._storage[id(var)] = source - def get_source_by_code(self, code: CodeType) -> t.Optional[str]: + def get_source_by_code(self, code: CodeType) -> str | None: try: return self._storage[id(code)] except KeyError: @@ -133,9 +130,9 @@ class _ConsoleLoader: class _InteractiveConsole(code.InteractiveInterpreter): - locals: t.Dict[str, t.Any] + locals: dict[str, t.Any] - def __init__(self, globals: t.Dict[str, t.Any], locals: t.Dict[str, t.Any]) -> None: + def __init__(self, globals: dict[str, t.Any], locals: dict[str, t.Any]) -> None: self.loader = _ConsoleLoader() locals = { **globals, @@ -147,7 +144,7 @@ class _InteractiveConsole(code.InteractiveInterpreter): super().__init__(locals) original_compile = self.compile - def compile(source: str, filename: str, symbol: str) -> t.Optional[CodeType]: + def compile(source: str, filename: str, symbol: str) -> CodeType | None: code = original_compile(source, filename, symbol) if code is not None: @@ -157,7 +154,7 @@ class _InteractiveConsole(code.InteractiveInterpreter): self.compile = compile # type: ignore[assignment] self.more = False - self.buffer: t.List[str] = [] + self.buffer: list[str] = [] def runsource(self, source: str, **kwargs: t.Any) -> str: # type: ignore source = f"{source.rstrip()}\n" @@ -188,7 +185,7 @@ class _InteractiveConsole(code.InteractiveInterpreter): te = DebugTraceback(exc, skip=1) sys.stdout._write(te.render_traceback_html()) # type: ignore - def showsyntaxerror(self, filename: t.Optional[str] = None) -> None: + def showsyntaxerror(self, filename: str | None = None) -> None: from .tbtools import DebugTraceback exc = t.cast(BaseException, sys.exc_info()[1]) @@ -204,8 +201,8 @@ class Console: def __init__( self, - globals: t.Optional[t.Dict[str, t.Any]] = None, - locals: t.Optional[t.Dict[str, t.Any]] = None, + globals: dict[str, t.Any] | None = None, + locals: dict[str, t.Any] | 
None = None, ) -> None: if locals is None: locals = {} diff --git a/src/werkzeug/debug/repr.py b/src/werkzeug/debug/repr.py index c0872f1..3bf15a7 100644 --- a/src/werkzeug/debug/repr.py +++ b/src/werkzeug/debug/repr.py @@ -4,6 +4,8 @@ repr, these expose more information and produce HTML instead of ASCII. Together with the CSS and JavaScript of the debugger this gives a colorful and more compact output. """ +from __future__ import annotations + import codecs import re import sys @@ -57,7 +59,7 @@ class _Helper: def __repr__(self) -> str: return "Type help(object) for help about object." - def __call__(self, topic: t.Optional[t.Any] = None) -> None: + def __call__(self, topic: t.Any | None = None) -> None: if topic is None: sys.stdout._write(f"{self!r}") # type: ignore return @@ -65,8 +67,6 @@ class _Helper: pydoc.help(topic) rv = sys.stdout.reset() # type: ignore - if isinstance(rv, bytes): - rv = rv.decode("utf-8", "ignore") paragraphs = _paragraph_re.split(rv) if len(paragraphs) > 1: title = paragraphs[0] @@ -81,7 +81,7 @@ helper = _Helper() def _add_subclass_info( - inner: str, obj: object, base: t.Union[t.Type, t.Tuple[t.Type, ...]] + inner: str, obj: object, base: t.Type | tuple[t.Type, ...] 
) -> str: if isinstance(base, tuple): for cls in base: @@ -97,8 +97,8 @@ def _add_subclass_info( def _sequence_repr_maker( left: str, right: str, base: t.Type, limit: int = 8 -) -> t.Callable[["DebugReprGenerator", t.Iterable, bool], str]: - def proxy(self: "DebugReprGenerator", obj: t.Iterable, recursive: bool) -> str: +) -> t.Callable[[DebugReprGenerator, t.Iterable, bool], str]: + def proxy(self: DebugReprGenerator, obj: t.Iterable, recursive: bool) -> str: if recursive: return _add_subclass_info(f"{left}...{right}", obj, base) buf = [left] @@ -120,7 +120,7 @@ def _sequence_repr_maker( class DebugReprGenerator: def __init__(self) -> None: - self._stack: t.List[t.Any] = [] + self._stack: list[t.Any] = [] list_repr = _sequence_repr_maker("[", "]", list) tuple_repr = _sequence_repr_maker("(", ")", tuple) @@ -132,11 +132,11 @@ class DebugReprGenerator: def regex_repr(self, obj: t.Pattern) -> str: pattern = repr(obj.pattern) - pattern = codecs.decode(pattern, "unicode-escape", "ignore") # type: ignore + pattern = codecs.decode(pattern, "unicode-escape", "ignore") pattern = f"r{pattern}" return f're.compile({pattern})' - def string_repr(self, obj: t.Union[str, bytes], limit: int = 70) -> str: + def string_repr(self, obj: str | bytes, limit: int = 70) -> str: buf = [''] r = repr(obj) @@ -165,7 +165,7 @@ class DebugReprGenerator: def dict_repr( self, - d: t.Union[t.Dict[int, None], t.Dict[str, int], t.Dict[t.Union[str, int], int]], + d: dict[int, None] | dict[str, int] | dict[str | int, int], recursive: bool, limit: int = 5, ) -> str: @@ -188,9 +188,7 @@ class DebugReprGenerator: buf.append("}") return _add_subclass_info("".join(buf), d, dict) - def object_repr( - self, obj: t.Optional[t.Union[t.Type[dict], t.Callable, t.Type[list]]] - ) -> str: + def object_repr(self, obj: type[dict] | t.Callable | type[list] | None) -> str: r = repr(obj) return f'{escape(r)}' @@ -244,7 +242,7 @@ class DebugReprGenerator: def dump_object(self, obj: object) -> str: repr = None - items: 
t.Optional[t.List[t.Tuple[str, str]]] = None + items: list[tuple[str, str]] | None = None if isinstance(obj, dict): title = "Contents of" @@ -266,12 +264,12 @@ class DebugReprGenerator: title += f" {object.__repr__(obj)[1:-1]}" return self.render_object_dump(items, title, repr) - def dump_locals(self, d: t.Dict[str, t.Any]) -> str: + def dump_locals(self, d: dict[str, t.Any]) -> str: items = [(key, self.repr(value)) for key, value in d.items()] return self.render_object_dump(items, "Local variables in frame") def render_object_dump( - self, items: t.List[t.Tuple[str, str]], title: str, repr: t.Optional[str] = None + self, items: list[tuple[str, str]], title: str, repr: str | None = None ) -> str: html_items = [] for key, value in items: diff --git a/src/werkzeug/debug/shared/debugger.js b/src/werkzeug/debug/shared/debugger.js index 2354f03..f463e9c 100644 --- a/src/werkzeug/debug/shared/debugger.js +++ b/src/werkzeug/debug/shared/debugger.js @@ -305,7 +305,8 @@ function handleConsoleSubmit(e, command, frameID) { wrapperSpan.append(spanToWrap); spanToWrap.hidden = true; - expansionButton.addEventListener("click", () => { + expansionButton.addEventListener("click", (event) => { + event.preventDefault(); spanToWrap.hidden = !spanToWrap.hidden; expansionButton.classList.toggle("open"); return false; diff --git a/src/werkzeug/debug/tbtools.py b/src/werkzeug/debug/tbtools.py index ea90de9..c45f56e 100644 --- a/src/werkzeug/debug/tbtools.py +++ b/src/werkzeug/debug/tbtools.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import linecache import os @@ -123,7 +125,7 @@ FRAME_HTML = """\ def _process_traceback( exc: BaseException, - te: t.Optional[traceback.TracebackException] = None, + te: traceback.TracebackException | None = None, *, skip: int = 0, hide: bool = True, @@ -146,7 +148,7 @@ def _process_traceback( frame_gen = itertools.islice(frame_gen, skip, None) del te.stack[:skip] - new_stack: t.List[DebugFrameSummary] = [] + new_stack: 
list[DebugFrameSummary] = [] hidden = False # Match each frame with the FrameSummary that was generated. @@ -175,7 +177,7 @@ def _process_traceback( elif hide_value or hidden: continue - frame_args: t.Dict[str, t.Any] = { + frame_args: dict[str, t.Any] = { "filename": fs.filename, "lineno": fs.lineno, "name": fs.name, @@ -184,7 +186,7 @@ def _process_traceback( } if hasattr(fs, "colno"): - frame_args["colno"] = fs.colno # type: ignore[attr-defined] + frame_args["colno"] = fs.colno frame_args["end_colno"] = fs.end_colno # type: ignore[attr-defined] new_stack.append(DebugFrameSummary(**frame_args)) @@ -221,7 +223,7 @@ class DebugTraceback: def __init__( self, exc: BaseException, - te: t.Optional[traceback.TracebackException] = None, + te: traceback.TracebackException | None = None, *, skip: int = 0, hide: bool = True, @@ -234,7 +236,7 @@ class DebugTraceback: @cached_property def all_tracebacks( self, - ) -> t.List[t.Tuple[t.Optional[str], traceback.TracebackException]]: + ) -> list[tuple[str | None, traceback.TracebackException]]: out = [] current = self._te @@ -261,7 +263,7 @@ class DebugTraceback: return out @cached_property - def all_frames(self) -> t.List["DebugFrameSummary"]: + def all_frames(self) -> list[DebugFrameSummary]: return [ f for _, te in self.all_tracebacks for f in te.stack # type: ignore[misc] ] @@ -325,7 +327,7 @@ class DebugTraceback: "evalex": "true" if evalex else "false", "evalex_trusted": "true" if evalex_trusted else "false", "console": "false", - "title": exc_lines[0], + "title": escape(exc_lines[0]), "exception": escape("".join(exc_lines)), "exception_type": escape(self._te.exc_type.__name__), "summary": self.render_traceback_html(include_title=False), @@ -351,8 +353,8 @@ class DebugFrameSummary(traceback.FrameSummary): def __init__( self, *, - locals: t.Dict[str, t.Any], - globals: t.Dict[str, t.Any], + locals: dict[str, t.Any], + globals: dict[str, t.Any], **kwargs: t.Any, ) -> None: super().__init__(locals=None, **kwargs) @@ -360,7 
+362,7 @@ class DebugFrameSummary(traceback.FrameSummary): self.global_ns = globals @cached_property - def info(self) -> t.Optional[str]: + def info(self) -> str | None: return self.local_ns.get("__traceback_info__") @cached_property diff --git a/src/werkzeug/exceptions.py b/src/werkzeug/exceptions.py index 013df72..2536129 100644 --- a/src/werkzeug/exceptions.py +++ b/src/werkzeug/exceptions.py @@ -43,6 +43,8 @@ code, you can add a second except for a specific subclass of an error: return e """ +from __future__ import annotations + import typing as t from datetime import datetime @@ -52,13 +54,12 @@ from markupsafe import Markup from ._internal import _get_environ if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIEnvironment from .datastructures import WWWAuthenticate from .sansio.response import Response - from .wrappers.request import Request as WSGIRequest # noqa: F401 - from .wrappers.response import Response as WSGIResponse # noqa: F401 + from .wrappers.request import Request as WSGIRequest + from .wrappers.response import Response as WSGIResponse class HTTPException(Exception): @@ -70,13 +71,13 @@ class HTTPException(Exception): Removed the ``wrap`` class method. 
""" - code: t.Optional[int] = None - description: t.Optional[str] = None + code: int | None = None + description: str | None = None def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, + description: str | None = None, + response: Response | None = None, ) -> None: super().__init__() if description is not None: @@ -92,14 +93,12 @@ class HTTPException(Exception): def get_description( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, + environ: WSGIEnvironment | None = None, + scope: dict | None = None, ) -> str: """Get the description.""" if self.description is None: description = "" - elif not isinstance(self.description, str): - description = str(self.description) else: description = self.description @@ -108,8 +107,8 @@ class HTTPException(Exception): def get_body( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, + environ: WSGIEnvironment | None = None, + scope: dict | None = None, ) -> str: """Get the HTML body.""" return ( @@ -122,17 +121,17 @@ class HTTPException(Exception): def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict | None = None, + ) -> list[tuple[str, str]]: """Get a list of headers.""" return [("Content-Type", "text/html; charset=utf-8")] def get_response( self, - environ: t.Optional[t.Union["WSGIEnvironment", "WSGIRequest"]] = None, - scope: t.Optional[dict] = None, - ) -> "Response": + environ: WSGIEnvironment | WSGIRequest | None = None, + scope: dict | None = None, + ) -> Response: """Get a response object. If one was passed to the exception it's returned directly. 
@@ -151,7 +150,7 @@ class HTTPException(Exception): return WSGIResponse(self.get_body(environ, scope), self.code, headers) def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Call the exception as WSGI application. @@ -196,7 +195,7 @@ class BadRequestKeyError(BadRequest, KeyError): #: useful in a debug mode. show_exception = False - def __init__(self, arg: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any): + def __init__(self, arg: str | None = None, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) if arg is None: @@ -205,7 +204,7 @@ class BadRequestKeyError(BadRequest, KeyError): KeyError.__init__(self, arg) @property # type: ignore - def description(self) -> str: # type: ignore + def description(self) -> str: if self.show_exception: return ( f"{self._description}\n" @@ -297,11 +296,9 @@ class Unauthorized(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - www_authenticate: t.Optional[ - t.Union["WWWAuthenticate", t.Iterable["WWWAuthenticate"]] - ] = None, + description: str | None = None, + response: Response | None = None, + www_authenticate: None | (WWWAuthenticate | t.Iterable[WWWAuthenticate]) = None, ) -> None: super().__init__(description, response) @@ -314,9 +311,9 @@ class Unauthorized(HTTPException): def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.www_authenticate: headers.extend(("WWW-Authenticate", str(x)) for x in self.www_authenticate) @@ -367,9 +364,9 @@ class MethodNotAllowed(HTTPException): def __init__( self, - valid_methods: t.Optional[t.Iterable[str]] = None, - description: t.Optional[str] = 
None, - response: t.Optional["Response"] = None, + valid_methods: t.Iterable[str] | None = None, + description: str | None = None, + response: Response | None = None, ) -> None: """Takes an optional list of valid http methods starting with werkzeug 0.3 the list will be mandatory.""" @@ -378,9 +375,9 @@ class MethodNotAllowed(HTTPException): def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.valid_methods: headers.append(("Allow", ", ".join(self.valid_methods))) @@ -524,10 +521,10 @@ class RequestedRangeNotSatisfiable(HTTPException): def __init__( self, - length: t.Optional[int] = None, + length: int | None = None, units: str = "bytes", - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, + description: str | None = None, + response: Response | None = None, ) -> None: """Takes an optional `Content-Range` header value based on ``length`` parameter. 
@@ -538,9 +535,9 @@ class RequestedRangeNotSatisfiable(HTTPException): def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.length is not None: headers.append(("Content-Range", f"{self.units} */{self.length}")) @@ -638,18 +635,18 @@ class _RetryAfter(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - retry_after: t.Optional[t.Union[datetime, int]] = None, + description: str | None = None, + response: Response | None = None, + retry_after: datetime | int | None = None, ) -> None: super().__init__(description, response) self.retry_after = retry_after def get_headers( self, - environ: t.Optional["WSGIEnvironment"] = None, - scope: t.Optional[dict] = None, - ) -> t.List[t.Tuple[str, str]]: + environ: WSGIEnvironment | None = None, + scope: dict | None = None, + ) -> list[tuple[str, str]]: headers = super().get_headers(environ, scope) if self.retry_after: @@ -728,9 +725,9 @@ class InternalServerError(HTTPException): def __init__( self, - description: t.Optional[str] = None, - response: t.Optional["Response"] = None, - original_exception: t.Optional[BaseException] = None, + description: str | None = None, + response: Response | None = None, + original_exception: BaseException | None = None, ) -> None: #: The original exception that caused this 500 error. 
Can be #: used by frameworks to provide context when handling @@ -809,7 +806,7 @@ class HTTPVersionNotSupported(HTTPException): ) -default_exceptions: t.Dict[int, t.Type[HTTPException]] = {} +default_exceptions: dict[int, type[HTTPException]] = {} def _find_exceptions() -> None: @@ -841,8 +838,8 @@ class Aborter: def __init__( self, - mapping: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, - extra: t.Optional[t.Dict[int, t.Type[HTTPException]]] = None, + mapping: dict[int, type[HTTPException]] | None = None, + extra: dict[int, type[HTTPException]] | None = None, ) -> None: if mapping is None: mapping = default_exceptions @@ -851,8 +848,8 @@ class Aborter: self.mapping.update(extra) def __call__( - self, code: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any - ) -> "te.NoReturn": + self, code: int | Response, *args: t.Any, **kwargs: t.Any + ) -> t.NoReturn: from .sansio.response import Response if isinstance(code, Response): @@ -864,9 +861,7 @@ class Aborter: raise self.mapping[code](*args, **kwargs) -def abort( - status: t.Union[int, "Response"], *args: t.Any, **kwargs: t.Any -) -> "te.NoReturn": +def abort(status: int | Response, *args: t.Any, **kwargs: t.Any) -> t.NoReturn: """Raises an :py:exc:`HTTPException` for the given status code or WSGI application. diff --git a/src/werkzeug/formparser.py b/src/werkzeug/formparser.py index 10d58ca..ee30666 100644 --- a/src/werkzeug/formparser.py +++ b/src/werkzeug/formparser.py @@ -1,13 +1,14 @@ -import typing as t -from functools import update_wrapper -from io import BytesIO -from itertools import chain -from typing import Union +from __future__ import annotations -from . 
import exceptions +import typing as t +from io import BytesIO +from urllib.parse import parse_qsl + +from ._internal import _plain_int from .datastructures import FileStorage from .datastructures import Headers from .datastructures import MultiDict +from .exceptions import RequestEntityTooLarge from .http import parse_options_header from .sansio.multipart import Data from .sansio.multipart import Epilogue @@ -15,8 +16,6 @@ from .sansio.multipart import Field from .sansio.multipart import File from .sansio.multipart import MultipartDecoder from .sansio.multipart import NeedData -from .urls import url_decode_stream -from .wsgi import _make_chunk_iter from .wsgi import get_content_length from .wsgi import get_input_stream @@ -38,10 +37,10 @@ if t.TYPE_CHECKING: class TStreamFactory(te.Protocol): def __call__( self, - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str], - content_length: t.Optional[int] = None, + total_content_length: int | None, + content_type: str | None, + filename: str | None, + content_length: int | None = None, ) -> t.IO[bytes]: ... 
@@ -49,17 +48,11 @@ if t.TYPE_CHECKING: F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def _exhaust(stream: t.IO[bytes]) -> None: - bts = stream.read(64 * 1024) - while bts: - bts = stream.read(64 * 1024) - - def default_stream_factory( - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str], - content_length: t.Optional[int] = None, + total_content_length: int | None, + content_type: str | None, + filename: str | None, + content_length: int | None = None, ) -> t.IO[bytes]: max_size = 1024 * 500 @@ -72,15 +65,15 @@ def default_stream_factory( def parse_form_data( - environ: "WSGIEnvironment", - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - max_content_length: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + environ: WSGIEnvironment, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + max_content_length: int | None = None, + cls: type[MultiDict] | None = None, silent: bool = True, -) -> "t_parse_result": + *, + max_form_parts: int | None = None, +) -> t_parse_result: """Parse the form data in the environ and return it as tuple in the form ``(stream, form, files)``. You should only call this method if the transport method is `POST`, `PUT`, or `PATCH`. @@ -92,21 +85,10 @@ def parse_form_data( This is a shortcut for the common usage of :class:`FormDataParser`. - Have a look at :doc:`/request_data` for more details. - - .. versionadded:: 0.5 - The `max_form_memory_size`, `max_content_length` and - `cls` parameters were added. - - .. versionadded:: 0.5.1 - The optional `silent` flag was added. - :param environ: the WSGI environment to be used for parsing. :param stream_factory: An optional callable that returns a new read and writeable file descriptor. This callable works the same as :meth:`Response._get_file_stream`. 
- :param charset: The character set for URL and url encoded form data. - :param errors: The encoding error behavior. :param max_form_memory_size: the maximum number of bytes to be accepted for in-memory stored form data. If the data exceeds the value specified an @@ -119,40 +101,33 @@ def parse_form_data( :param cls: an optional dict class to use. If this is not specified or `None` the default :class:`MultiDict` is used. :param silent: If set to False parsing errors will not be caught. + :param max_form_parts: The maximum number of multipart parts to be parsed. If this + is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. :return: A tuple in the form ``(stream, form, files)``. + + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. + + .. versionchanged:: 2.3 + Added the ``max_form_parts`` parameter. + + .. versionadded:: 0.5.1 + Added the ``silent`` parameter. + + .. versionadded:: 0.5 + Added the ``max_form_memory_size``, ``max_content_length``, and ``cls`` + parameters. """ return FormDataParser( - stream_factory, - charset, - errors, - max_form_memory_size, - max_content_length, - cls, - silent, + stream_factory=stream_factory, + max_form_memory_size=max_form_memory_size, + max_content_length=max_content_length, + max_form_parts=max_form_parts, + silent=silent, + cls=cls, ).parse_from_environ(environ) -def exhaust_stream(f: F) -> F: - """Helper decorator for methods that exhausts the stream on return.""" - - def wrapper(self, stream, *args, **kwargs): # type: ignore - try: - return f(self, stream, *args, **kwargs) - finally: - exhaust = getattr(stream, "exhaust", None) - - if exhaust is not None: - exhaust() - else: - while True: - chunk = stream.read(1024 * 64) - - if not chunk: - break - - return update_wrapper(t.cast(F, wrapper), f) - - class FormDataParser: """This class implements parsing of form data for Werkzeug. By itself it can parse multipart and url encoded form data. 
It can be subclassed @@ -160,13 +135,9 @@ class FormDataParser: untouched stream and expose it as separate attributes on a request object. - .. versionadded:: 0.8 - :param stream_factory: An optional callable that returns a new read and writeable file descriptor. This callable works the same as :meth:`Response._get_file_stream`. - :param charset: The character set for URL and url encoded form data. - :param errors: The encoding error behavior. :param max_form_memory_size: the maximum number of bytes to be accepted for in-memory stored form data. If the data exceeds the value specified an @@ -179,26 +150,38 @@ class FormDataParser: :param cls: an optional dict class to use. If this is not specified or `None` the default :class:`MultiDict` is used. :param silent: If set to False parsing errors will not be caught. + :param max_form_parts: The maximum number of multipart parts to be parsed. If this + is exceeded, a :exc:`~exceptions.RequestEntityTooLarge` exception is raised. + + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. + + .. versionchanged:: 3.0 + The ``parse_functions`` attribute and ``get_parse_func`` methods were removed. + + .. versionchanged:: 2.2.3 + Added the ``max_form_parts`` parameter. + + .. 
versionadded:: 0.8 """ def __init__( self, - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - max_content_length: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + max_content_length: int | None = None, + cls: type[MultiDict] | None = None, silent: bool = True, + *, + max_form_parts: int | None = None, ) -> None: if stream_factory is None: stream_factory = default_stream_factory self.stream_factory = stream_factory - self.charset = charset - self.errors = errors self.max_form_memory_size = max_form_memory_size self.max_content_length = max_content_length + self.max_form_parts = max_form_parts if cls is None: cls = MultiDict @@ -206,34 +189,29 @@ class FormDataParser: self.cls = cls self.silent = silent - def get_parse_func( - self, mimetype: str, options: t.Dict[str, str] - ) -> t.Optional[ - t.Callable[ - ["FormDataParser", t.IO[bytes], str, t.Optional[int], t.Dict[str, str]], - "t_parse_result", - ] - ]: - return self.parse_functions.get(mimetype) - - def parse_from_environ(self, environ: "WSGIEnvironment") -> "t_parse_result": + def parse_from_environ(self, environ: WSGIEnvironment) -> t_parse_result: """Parses the information from the environment as form data. :param environ: the WSGI environment to be used for parsing. :return: A tuple in the form ``(stream, form, files)``. 
""" - content_type = environ.get("CONTENT_TYPE", "") + stream = get_input_stream(environ, max_content_length=self.max_content_length) content_length = get_content_length(environ) - mimetype, options = parse_options_header(content_type) - return self.parse(get_input_stream(environ), mimetype, content_length, options) + mimetype, options = parse_options_header(environ.get("CONTENT_TYPE")) + return self.parse( + stream, + content_length=content_length, + mimetype=mimetype, + options=options, + ) def parse( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Optional[t.Dict[str, str]] = None, - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str] | None = None, + ) -> t_parse_result: """Parses the information from the given stream, mimetype, content length and mimetype parameters. @@ -243,43 +221,40 @@ class FormDataParser: :param options: optional mimetype parameters (used for the multipart boundary for instance) :return: A tuple in the form ``(stream, form, files)``. + + .. versionchanged:: 3.0 + The invalid ``application/x-url-encoded`` content type is not + treated as ``application/x-www-form-urlencoded``. 
""" - if ( - self.max_content_length is not None - and content_length is not None - and content_length > self.max_content_length - ): - # if the input stream is not exhausted, firefox reports Connection Reset - _exhaust(stream) - raise exceptions.RequestEntityTooLarge() + if mimetype == "multipart/form-data": + parse_func = self._parse_multipart + elif mimetype == "application/x-www-form-urlencoded": + parse_func = self._parse_urlencoded + else: + return stream, self.cls(), self.cls() if options is None: options = {} - parse_func = self.get_parse_func(mimetype, options) - - if parse_func is not None: - try: - return parse_func(self, stream, mimetype, content_length, options) - except ValueError: - if not self.silent: - raise + try: + return parse_func(stream, mimetype, content_length, options) + except ValueError: + if not self.silent: + raise return stream, self.cls(), self.cls() - @exhaust_stream def _parse_multipart( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Dict[str, str], - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str], + ) -> t_parse_result: parser = MultiPartParser( - self.stream_factory, - self.charset, - self.errors, + stream_factory=self.stream_factory, max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, cls=self.cls, ) boundary = options.get("boundary", "").encode("ascii") @@ -290,66 +265,43 @@ class FormDataParser: form, files = parser.parse(stream, boundary, content_length) return stream, form, files - @exhaust_stream def _parse_urlencoded( self, stream: t.IO[bytes], mimetype: str, - content_length: t.Optional[int], - options: t.Dict[str, str], - ) -> "t_parse_result": + content_length: int | None, + options: dict[str, str], + ) -> t_parse_result: if ( self.max_form_memory_size is not None and content_length is not None and content_length > self.max_form_memory_size ): - # if the input stream is not exhausted, firefox reports Connection 
Reset - _exhaust(stream) - raise exceptions.RequestEntityTooLarge() + raise RequestEntityTooLarge() - form = url_decode_stream(stream, self.charset, errors=self.errors, cls=self.cls) - return stream, form, self.cls() + try: + items = parse_qsl( + stream.read().decode(), + keep_blank_values=True, + errors="werkzeug.url_quote", + ) + except ValueError as e: + raise RequestEntityTooLarge() from e - #: mapping of mimetypes to parsing functions - parse_functions: t.Dict[ - str, - t.Callable[ - ["FormDataParser", t.IO[bytes], str, t.Optional[int], t.Dict[str, str]], - "t_parse_result", - ], - ] = { - "multipart/form-data": _parse_multipart, - "application/x-www-form-urlencoded": _parse_urlencoded, - "application/x-url-encoded": _parse_urlencoded, - } - - -def _line_parse(line: str) -> t.Tuple[str, bool]: - """Removes line ending characters and returns a tuple (`stripped_line`, - `is_terminated`). - """ - if line[-2:] == "\r\n": - return line[:-2], True - - elif line[-1:] in {"\r", "\n"}: - return line[:-1], True - - return line, False + return stream, self.cls(items), self.cls() class MultiPartParser: def __init__( self, - stream_factory: t.Optional["TStreamFactory"] = None, - charset: str = "utf-8", - errors: str = "replace", - max_form_memory_size: t.Optional[int] = None, - cls: t.Optional[t.Type[MultiDict]] = None, + stream_factory: TStreamFactory | None = None, + max_form_memory_size: int | None = None, + cls: type[MultiDict] | None = None, buffer_size: int = 64 * 1024, + max_form_parts: int | None = None, ) -> None: - self.charset = charset - self.errors = errors self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts if stream_factory is None: stream_factory = default_stream_factory @@ -360,10 +312,9 @@ class MultiPartParser: cls = MultiDict self.cls = cls - self.buffer_size = buffer_size - def fail(self, message: str) -> "te.NoReturn": + def fail(self, message: str) -> te.NoReturn: raise ValueError(message) def 
get_part_charset(self, headers: Headers) -> str: @@ -371,18 +322,23 @@ class MultiPartParser: content_type = headers.get("content-type") if content_type: - mimetype, ct_params = parse_options_header(content_type) - return ct_params.get("charset", self.charset) + parameters = parse_options_header(content_type)[1] + ct_charset = parameters.get("charset", "").lower() - return self.charset + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. + if ct_charset in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + return ct_charset + + return "utf-8" def start_file_streaming( - self, event: File, total_content_length: t.Optional[int] + self, event: File, total_content_length: int | None ) -> t.IO[bytes]: content_type = event.headers.get("content-type") try: - content_length = int(event.headers["content-length"]) + content_length = _plain_int(event.headers["content-length"]) except (KeyError, ValueError): content_length = 0 @@ -395,27 +351,22 @@ class MultiPartParser: return container def parse( - self, stream: t.IO[bytes], boundary: bytes, content_length: t.Optional[int] - ) -> t.Tuple[MultiDict, MultiDict]: - container: t.Union[t.IO[bytes], t.List[bytes]] + self, stream: t.IO[bytes], boundary: bytes, content_length: int | None + ) -> tuple[MultiDict, MultiDict]: + current_part: Field | File + container: t.IO[bytes] | list[bytes] _write: t.Callable[[bytes], t.Any] - iterator = chain( - _make_chunk_iter( - stream, - limit=content_length, - buffer_size=self.buffer_size, - ), - [None], + parser = MultipartDecoder( + boundary, + max_form_memory_size=self.max_form_memory_size, + max_parts=self.max_form_parts, ) - parser = MultipartDecoder(boundary, self.max_form_memory_size) - fields = [] files = [] - current_part: Union[Field, File] - for data in iterator: + for data in _chunk_iter(stream.read, self.buffer_size): parser.receive_data(data) event = parser.next_event() while not isinstance(event, (Epilogue, 
NeedData)): @@ -432,7 +383,7 @@ class MultiPartParser: if not event.more_data: if isinstance(current_part, Field): value = b"".join(container).decode( - self.get_part_charset(current_part.headers), self.errors + self.get_part_charset(current_part.headers), "replace" ) fields.append((current_part.name, value)) else: @@ -453,3 +404,18 @@ class MultiPartParser: event = parser.next_event() return self.cls(fields), self.cls(files) + + +def _chunk_iter(read: t.Callable[[int], bytes], size: int) -> t.Iterator[bytes | None]: + """Read data in chunks for multipart/form-data parsing. Stop if no data is read. + Yield ``None`` at the end to signal end of parsing. + """ + while True: + data = read(size) + + if not data: + break + + yield data + + yield None diff --git a/src/werkzeug/http.py b/src/werkzeug/http.py index 9777685..8280f51 100644 --- a/src/werkzeug/http.py +++ b/src/werkzeug/http.py @@ -1,7 +1,7 @@ -import base64 +from __future__ import annotations + import email.utils import re -import typing import typing as t import warnings from datetime import date @@ -13,74 +13,20 @@ from enum import Enum from hashlib import sha1 from time import mktime from time import struct_time -from urllib.parse import unquote_to_bytes as _unquote +from urllib.parse import quote +from urllib.parse import unquote from urllib.request import parse_http_list as _parse_list_header -from ._internal import _cookie_quote from ._internal import _dt_as_utc -from ._internal import _make_cookie_domain -from ._internal import _to_bytes -from ._internal import _to_str -from ._internal import _wsgi_decoding_dance +from ._internal import _plain_int if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment -# for explanation of "media-range", etc. 
see Sections 5.3.{1,2} of RFC 7231 -_accept_re = re.compile( - r""" - ( # media-range capturing-parenthesis - [^\s;,]+ # type/subtype - (?:[ \t]*;[ \t]* # ";" - (?: # parameter non-capturing-parenthesis - [^\s;,q][^\s;,]* # token that doesn't start with "q" - | # or - q[^\s;,=][^\s;,]* # token that is more than just "q" - ) - )* # zero or more parameters - ) # end of media-range - (?:[ \t]*;[ \t]*q= # weight is a "q" parameter - (\d*(?:\.\d+)?) # qvalue capturing-parentheses - [^,]* # "extension" accept params: who cares? - )? # accept params are optional - """, - re.VERBOSE, -) _token_chars = frozenset( "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~" ) _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') -_option_header_piece_re = re.compile( - r""" - ;\s*,?\s* # newlines were replaced with commas - (?P - "[^"\\]*(?:\\.[^"\\]*)*" # quoted string - | - [^\s;,=*]+ # token - ) - (?:\*(?P\d+))? # *1, optional continuation index - \s* - (?: # optionally followed by =value - (?: # equals sign, possibly with encoding - \*\s*=\s* # * indicates extended notation - (?: # optional encoding - (?P[^\s]+?) - '(?P[^\s]*?)' - )? - | - =\s* # basic notation - ) - (?P - "[^"\\]*(?:\\.[^"\\]*)*" # quoted string - | - [^;,]+ # token - )? - )? - \s* - """, - flags=re.VERBOSE, -) -_option_header_start_mime_type = re.compile(r",\s*([^;,\s]+)([;,]\s*.+)?") _entity_headers = frozenset( [ "allow", @@ -190,108 +136,155 @@ class COOP(Enum): SAME_ORIGIN = "same-origin" -def quote_header_value( - value: t.Union[str, int], extra_chars: str = "", allow_token: bool = True -) -> str: - """Quote a header value if necessary. +def quote_header_value(value: t.Any, allow_token: bool = True) -> str: + """Add double quotes around a header value. If the header contains only ASCII token + characters, it will be returned unchanged. If the header contains ``"`` or ``\\`` + characters, they will be escaped with an additional ``\\`` character. 
+ + This is the reverse of :func:`unquote_header_value`. + + :param value: The value to quote. Will be converted to a string. + :param allow_token: Disable to quote the value even if it only has token characters. + + .. versionchanged:: 3.0 + Passing bytes is not supported. + + .. versionchanged:: 3.0 + The ``extra_chars`` parameter is removed. + + .. versionchanged:: 2.3 + The value is quoted if it is the empty string. .. versionadded:: 0.5 - - :param value: the value to quote. - :param extra_chars: a list of extra characters to skip quoting. - :param allow_token: if this is enabled token values are returned - unchanged. """ - if isinstance(value, bytes): - value = value.decode("latin1") value = str(value) + + if not value: + return '""' + if allow_token: - token_chars = _token_chars | set(extra_chars) - if set(value).issubset(token_chars): + token_chars = _token_chars + + if token_chars.issuperset(value): return value + value = value.replace("\\", "\\\\").replace('"', '\\"') return f'"{value}"' -def unquote_header_value(value: str, is_filename: bool = False) -> str: - r"""Unquotes a header value. (Reversal of :func:`quote_header_value`). - This does not use the real unquoting but what browsers are actually - using for quoting. +def unquote_header_value(value: str) -> str: + """Remove double quotes and decode slash-escaped ``"`` and ``\\`` characters in a + header value. - .. versionadded:: 0.5 + This is the reverse of :func:`quote_header_value`. - :param value: the header value to unquote. - :param is_filename: The value represents a filename or path. + :param value: The header value to unquote. + + .. versionchanged:: 3.0 + The ``is_filename`` parameter is removed. """ - if value and value[0] == value[-1] == '"': - # this is not the real unquoting, but fixing this so that the - # RFC is met will result in bugs with internet explorer and - # probably some other browsers as well. 
IE for example is - # uploading files with "C:\foo\bar.txt" as filename + if len(value) >= 2 and value[0] == value[-1] == '"': value = value[1:-1] + return value.replace("\\\\", "\\").replace('\\"', '"') - # if this is a filename and the starting characters look like - # a UNC path, then just return the value without quotes. Using the - # replace sequence below on a UNC path has the effect of turning - # the leading double slash into a single slash and then - # _fix_ie_filename() doesn't work correctly. See #458. - if not is_filename or value[:2] != "\\\\": - return value.replace("\\\\", "\\").replace('\\"', '"') return value -def dump_options_header( - header: t.Optional[str], options: t.Mapping[str, t.Optional[t.Union[str, int]]] -) -> str: - """The reverse function to :func:`parse_options_header`. +def dump_options_header(header: str | None, options: t.Mapping[str, t.Any]) -> str: + """Produce a header value and ``key=value`` parameters separated by semicolons + ``;``. For example, the ``Content-Type`` header. - :param header: the header to dump - :param options: a dict of options to append. + .. code-block:: python + + dump_options_header("text/html", {"charset": "UTF-8"}) + 'text/html; charset=UTF-8' + + This is the reverse of :func:`parse_options_header`. + + If a value contains non-token characters, it will be quoted. + + If a value is ``None``, the parameter is skipped. + + In some keys for some headers, a UTF-8 value can be encoded using a special + ``key*=UTF-8''value`` form, where ``value`` is percent encoded. This function will + not produce that format automatically, but if a given key ends with an asterisk + ``*``, the value is assumed to have that form and will not be quoted further. + + :param header: The primary header value. + :param options: Parameters to encode as ``key=value`` pairs. + + .. versionchanged:: 2.3 + Keys with ``None`` values are skipped rather than treated as a bare key. + + .. 
versionchanged:: 2.2.3 + If a key ends with ``*``, its value will not be quoted. """ segments = [] + if header is not None: segments.append(header) + for key, value in options.items(): if value is None: - segments.append(key) + continue + + if key[-1] == "*": + segments.append(f"{key}={value}") else: segments.append(f"{key}={quote_header_value(value)}") + return "; ".join(segments) -def dump_header( - iterable: t.Union[t.Dict[str, t.Union[str, int]], t.Iterable[str]], - allow_token: bool = True, -) -> str: - """Dump an HTTP header again. This is the reversal of - :func:`parse_list_header`, :func:`parse_set_header` and - :func:`parse_dict_header`. This also quotes strings that include an - equals sign unless you pass it as dict of key, value pairs. +def dump_header(iterable: dict[str, t.Any] | t.Iterable[t.Any]) -> str: + """Produce a header value from a list of items or ``key=value`` pairs, separated by + commas ``,``. - >>> dump_header({'foo': 'bar baz'}) - 'foo="bar baz"' - >>> dump_header(('foo', 'bar baz')) - 'foo, "bar baz"' + This is the reverse of :func:`parse_list_header`, :func:`parse_dict_header`, and + :func:`parse_set_header`. - :param iterable: the iterable or dict of values to quote. - :param allow_token: if set to `False` tokens as values are disallowed. - See :func:`quote_header_value` for more details. + If a value contains non-token characters, it will be quoted. + + If a value is ``None``, the key is output alone. + + In some keys for some headers, a UTF-8 value can be encoded using a special + ``key*=UTF-8''value`` form, where ``value`` is percent encoded. This function will + not produce that format automatically, but if a given key ends with an asterisk + ``*``, the value is assumed to have that form and will not be quoted further. + + .. code-block:: python + + dump_header(["foo", "bar baz"]) + 'foo, "bar baz"' + + dump_header({"foo": "bar baz"}) + 'foo="bar baz"' + + :param iterable: The items to create a header from. + + .. 
versionchanged:: 3.0 + The ``allow_token`` parameter is removed. + + .. versionchanged:: 2.2.3 + If a key ends with ``*``, its value will not be quoted. """ if isinstance(iterable, dict): items = [] + for key, value in iterable.items(): if value is None: items.append(key) + elif key[-1] == "*": + items.append(f"{key}={value}") else: - items.append( - f"{key}={quote_header_value(value, allow_token=allow_token)}" - ) + items.append(f"{key}={quote_header_value(value)}") else: - items = [quote_header_value(x, allow_token=allow_token) for x in iterable] + items = [quote_header_value(x) for x in iterable] + return ", ".join(items) -def dump_csp_header(header: "ds.ContentSecurityPolicy") -> str: +def dump_csp_header(header: ds.ContentSecurityPolicy) -> str: """Dump a Content Security Policy header. These are structured into policies such as "default-src 'self'; @@ -304,187 +297,287 @@ def dump_csp_header(header: "ds.ContentSecurityPolicy") -> str: return "; ".join(f"{key} {value}" for key, value in header.items()) -def parse_list_header(value: str) -> t.List[str]: - """Parse lists as described by RFC 2068 Section 2. +def parse_list_header(value: str) -> list[str]: + """Parse a header value that consists of a list of comma separated items according + to `RFC 9110 `__. - In particular, parse comma-separated lists where the elements of - the list may include quoted-strings. A quoted-string could - contain a comma. A non-quoted string could have quotes in the - middle. Quotes are removed automatically after parsing. + This extends :func:`urllib.request.parse_http_list` to remove surrounding quotes + from values. - It basically works like :func:`parse_set_header` just that items - may appear multiple times and case sensitivity is preserved. + .. 
code-block:: python - The return value is a standard :class:`list`: + parse_list_header('token, "quoted value"') + ['token', 'quoted value'] - >>> parse_list_header('token, "quoted value"') - ['token', 'quoted value'] - - To create a header from the :class:`list` again, use the - :func:`dump_header` function. - - :param value: a string with a list header. - :return: :class:`list` - """ - result = [] - for item in _parse_list_header(value): - if item[:1] == item[-1:] == '"': - item = unquote_header_value(item[1:-1]) - result.append(item) - return result - - -def parse_dict_header(value: str, cls: t.Type[dict] = dict) -> t.Dict[str, str]: - """Parse lists of key, value pairs as described by RFC 2068 Section 2 and - convert them into a python dict (or any other mapping object created from - the type with a dict like interface provided by the `cls` argument): - - >>> d = parse_dict_header('foo="is a fish", bar="as well"') - >>> type(d) is dict - True - >>> sorted(d.items()) - [('bar', 'as well'), ('foo', 'is a fish')] - - If there is no value for a key it will be `None`: - - >>> parse_dict_header('key_without_value') - {'key_without_value': None} - - To create a header from the :class:`dict` again, use the - :func:`dump_header` function. - - .. versionchanged:: 0.9 - Added support for `cls` argument. - - :param value: a string with a dict header. - :param cls: callable to use for storage of parsed results. 
- :return: an instance of `cls` - """ - result = cls() - if isinstance(value, bytes): - value = value.decode("latin1") - for item in _parse_list_header(value): - if "=" not in item: - result[item] = None - continue - name, value = item.split("=", 1) - if value[:1] == value[-1:] == '"': - value = unquote_header_value(value[1:-1]) - result[name] = value - return result - - -def parse_options_header(value: t.Optional[str]) -> t.Tuple[str, t.Dict[str, str]]: - """Parse a ``Content-Type``-like header into a tuple with the - value and any options: - - >>> parse_options_header('text/html; charset=utf8') - ('text/html', {'charset': 'utf8'}) - - This should is not for ``Cache-Control``-like headers, which use a - different format. For those, use :func:`parse_dict_header`. + This is the reverse of :func:`dump_header`. :param value: The header value to parse. + """ + result = [] + + for item in _parse_list_header(value): + if len(item) >= 2 and item[0] == item[-1] == '"': + item = item[1:-1] + + result.append(item) + + return result + + +def parse_dict_header(value: str) -> dict[str, str | None]: + """Parse a list header using :func:`parse_list_header`, then parse each item as a + ``key=value`` pair. + + .. code-block:: python + + parse_dict_header('a=b, c="d, e", f') + {"a": "b", "c": "d, e", "f": None} + + This is the reverse of :func:`dump_header`. + + If a key does not have a value, it is ``None``. + + This handles charsets for values as described in + `RFC 2231 `__. Only ASCII, UTF-8, + and ISO-8859-1 charsets are accepted, otherwise the value remains quoted. + + :param value: The header value to parse. + + .. versionchanged:: 3.0 + Passing bytes is not supported. + + .. versionchanged:: 3.0 + The ``cls`` argument is removed. + + .. versionchanged:: 2.3 + Added support for ``key*=charset''value`` encoded items. + + .. versionchanged:: 0.9 + The ``cls`` argument was added. 
+ """ + result: dict[str, str | None] = {} + + for item in parse_list_header(value): + key, has_value, value = item.partition("=") + key = key.strip() + + if not has_value: + result[key] = None + continue + + value = value.strip() + encoding: str | None = None + + if key[-1] == "*": + # key*=charset''value becomes key=value, where value is percent encoded + # adapted from parse_options_header, without the continuation handling + key = key[:-1] + match = _charset_value_re.match(value) + + if match: + # If there is a charset marker in the value, split it off. + encoding, value = match.groups() + encoding = encoding.lower() + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. An invalid encoding will leave the + # value quoted. + if encoding in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + # invalid bytes are replaced during unquoting + value = unquote(value, encoding=encoding) + + if len(value) >= 2 and value[0] == value[-1] == '"': + value = value[1:-1] + + result[key] = value + + return result + + +# https://httpwg.org/specs/rfc9110.html#parameter +_parameter_re = re.compile( + r""" + # don't match multiple empty parts, that causes backtracking + \s*;\s* # find the part delimiter + (?: + ([\w!#$%&'*+\-.^`|~]+) # key, one or more token chars + = # equals, with no space on either side + ( # value, token or quoted string + [\w!#$%&'*+\-.^`|~]+ # one or more token chars + | + "(?:\\\\|\\"|.)*?" # quoted string, consuming slash escapes + ) + )? 
# optionally match key=value, to account for empty parts + """, + re.ASCII | re.VERBOSE, +) +# https://www.rfc-editor.org/rfc/rfc2231#section-4 +_charset_value_re = re.compile( + r""" + ([\w!#$%&*+\-.^`|~]*)' # charset part, could be empty + [\w!#$%&*+\-.^`|~]*' # don't care about language part, usually empty + ([\w!#$%&'*+\-.^`|~]+) # one or more token chars with percent encoding + """, + re.ASCII | re.VERBOSE, +) +# https://www.rfc-editor.org/rfc/rfc2231#section-3 +_continuation_re = re.compile(r"\*(\d+)$", re.ASCII) + + +def parse_options_header(value: str | None) -> tuple[str, dict[str, str]]: + """Parse a header that consists of a value with ``key=value`` parameters separated + by semicolons ``;``. For example, the ``Content-Type`` header. + + .. code-block:: python + + parse_options_header("text/html; charset=UTF-8") + ('text/html', {'charset': 'UTF-8'}) + + parse_options_header("") + ("", {}) + + This is the reverse of :func:`dump_options_header`. + + This parses valid parameter parts as described in + `RFC 9110 `__. Invalid parts are + skipped. + + This handles continuations and charsets as described in + `RFC 2231 `__, although not as + strictly as the RFC. Only ASCII, UTF-8, and ISO-8859-1 charsets are accepted, + otherwise the value remains quoted. + + Clients may not be consistent in how they handle a quote character within a quoted + value. The `HTML Standard `__ + replaces it with ``%22`` in multipart form data. + `RFC 9110 `__ uses backslash + escapes in HTTP headers. Both are decoded to the ``"`` character. + + Clients may not be consistent in how they handle non-ASCII characters. HTML + documents must declare ````, otherwise browsers may replace with + HTML character references, which can be decoded using :func:`html.unescape`. + + :param value: The header value to parse. + :return: ``(value, options)``, where ``options`` is a dict + + .. 
versionchanged:: 2.3 + Invalid parts, such as keys with no value, quoted keys, and incorrectly quoted + values, are discarded instead of treating as ``None``. + + .. versionchanged:: 2.3 + Only ASCII, UTF-8, and ISO-8859-1 are accepted for charset values. + + .. versionchanged:: 2.3 + Escaped quotes in quoted values, like ``%22`` and ``\\"``, are handled. .. versionchanged:: 2.2 Option names are always converted to lowercase. - .. versionchanged:: 2.1 - The ``multiple`` parameter is deprecated and will be removed in - Werkzeug 2.2. + .. versionchanged:: 2.2 + The ``multiple`` parameter was removed. .. versionchanged:: 0.15 :rfc:`2231` parameter continuations are handled. .. versionadded:: 0.5 """ - if not value: + if value is None: return "", {} - result: t.List[t.Any] = [] + value, _, rest = value.partition(";") + value = value.strip() + rest = rest.strip() - value = "," + value.replace("\n", ",") - while value: - match = _option_header_start_mime_type.match(value) - if not match: - break - result.append(match.group(1)) # mimetype - options: t.Dict[str, str] = {} - # Parse options - rest = match.group(2) - encoding: t.Optional[str] - continued_encoding: t.Optional[str] = None - while rest: - optmatch = _option_header_piece_re.match(rest) - if not optmatch: - break - option, count, encoding, language, option_value = optmatch.groups() - # Continuations don't have to supply the encoding after the - # first line. If we're in a continuation, track the current - # encoding to use for subsequent lines. Reset it when the - # continuation ends. 
- if not count: - continued_encoding = None - else: - if not encoding: - encoding = continued_encoding + if not value or not rest: + # empty (invalid) value, or value without options + return value, {} + + rest = f";{rest}" + options: dict[str, str] = {} + encoding: str | None = None + continued_encoding: str | None = None + + for pk, pv in _parameter_re.findall(rest): + if not pk: + # empty or invalid part + continue + + pk = pk.lower() + + if pk[-1] == "*": + # key*=charset''value becomes key=value, where value is percent encoded + pk = pk[:-1] + match = _charset_value_re.match(pv) + + if match: + # If there is a valid charset marker in the value, split it off. + encoding, pv = match.groups() + # This might be the empty string, handled next. + encoding = encoding.lower() + + # No charset marker, or marker with empty charset value. + if not encoding: + encoding = continued_encoding + + # A safe list of encodings. Modern clients should only send ASCII or UTF-8. + # This list will not be extended further. An invalid encoding will leave the + # value quoted. + if encoding in {"ascii", "us-ascii", "utf-8", "iso-8859-1"}: + # Continuation parts don't require their own charset marker. This is + # looser than the RFC, it will persist across different keys and allows + # changing the charset during a continuation. But this implementation is + # much simpler than tracking the full state. continued_encoding = encoding - option = unquote_header_value(option).lower() + # invalid bytes are replaced during unquoting + pv = unquote(pv, encoding=encoding) - if option_value is not None: - option_value = unquote_header_value(option_value, option == "filename") + # Remove quotes. At this point the value cannot be empty or a single quote. 
+ if pv[0] == pv[-1] == '"': + # HTTP headers use slash, multipart form data uses percent + pv = pv[1:-1].replace("\\\\", "\\").replace('\\"', '"').replace("%22", '"') - if encoding is not None: - option_value = _unquote(option_value).decode(encoding) + match = _continuation_re.search(pk) - if count: - # Continuations append to the existing value. For - # simplicity, this ignores the possibility of - # out-of-order indices, which shouldn't happen anyway. - if option_value is not None: - options[option] = options.get(option, "") + option_value - else: - options[option] = option_value # type: ignore[assignment] + if match: + # key*0=a; key*1=b becomes key=ab + pk = pk[: match.start()] + options[pk] = options.get(pk, "") + pv + else: + options[pk] = pv - rest = rest[optmatch.end() :] - result.append(options) - return tuple(result) # type: ignore[return-value] - - return tuple(result) if result else ("", {}) # type: ignore[return-value] + return value, options +_q_value_re = re.compile(r"-?\d+(\.\d+)?", re.ASCII) _TAnyAccept = t.TypeVar("_TAnyAccept", bound="ds.Accept") -@typing.overload -def parse_accept_header(value: t.Optional[str]) -> "ds.Accept": +@t.overload +def parse_accept_header(value: str | None) -> ds.Accept: ... -@typing.overload -def parse_accept_header( - value: t.Optional[str], cls: t.Type[_TAnyAccept] -) -> _TAnyAccept: +@t.overload +def parse_accept_header(value: str | None, cls: type[_TAnyAccept]) -> _TAnyAccept: ... def parse_accept_header( - value: t.Optional[str], cls: t.Optional[t.Type[_TAnyAccept]] = None + value: str | None, cls: type[_TAnyAccept] | None = None ) -> _TAnyAccept: - """Parses an HTTP Accept-* header. This does not implement a complete - valid algorithm but one that supports at least value and quality - extraction. + """Parse an ``Accept`` header according to + `RFC 9110 `__. - Returns a new :class:`Accept` object (basically a list of ``(value, quality)`` - tuples sorted by the quality with some additional accessor methods). 
+ Returns an :class:`.Accept` instance, which can sort and inspect items based on + their quality parameter. When parsing ``Accept-Charset``, ``Accept-Encoding``, or + ``Accept-Language``, pass the appropriate :class:`.Accept` subclass. - The second parameter can be a subclass of :class:`Accept` that is created - with the parsed values and returned. + :param value: The header value to parse. + :param cls: The :class:`.Accept` class to wrap the result in. + :return: An instance of ``cls``. - :param value: the accept header string to be parsed. - :param cls: the wrapper class for the return value (can be - :class:`Accept` or a subclass thereof) - :return: an instance of `cls`. + .. versionchanged:: 2.3 + Parse according to RFC 9110. Items with invalid ``q`` values are skipped. """ if cls is None: cls = t.cast(t.Type[_TAnyAccept], ds.Accept) @@ -493,38 +586,57 @@ def parse_accept_header( return cls(None) result = [] - for match in _accept_re.finditer(value): - quality_match = match.group(2) - if not quality_match: - quality: float = 1 + + for item in parse_list_header(value): + item, options = parse_options_header(item) + + if "q" in options: + # pop q, remaining options are reconstructed + q_str = options.pop("q").strip() + + if _q_value_re.fullmatch(q_str) is None: + # ignore an invalid q + continue + + q = float(q_str) + + if q < 0 or q > 1: + # ignore an invalid q + continue else: - quality = max(min(float(quality_match), 1), 0) - result.append((match.group(1), quality)) + q = 1 + + if options: + # reconstruct the media type with any options + item = dump_options_header(item, options) + + result.append((item, q)) + return cls(result) -_TAnyCC = t.TypeVar("_TAnyCC", bound="ds._CacheControl") +_TAnyCC = t.TypeVar("_TAnyCC", bound="ds.cache_control._CacheControl") _t_cc_update = t.Optional[t.Callable[[_TAnyCC], None]] -@typing.overload +@t.overload def parse_cache_control_header( - value: t.Optional[str], on_update: _t_cc_update, cls: None = None -) -> 
"ds.RequestCacheControl": + value: str | None, on_update: _t_cc_update, cls: None = None +) -> ds.RequestCacheControl: ... -@typing.overload +@t.overload def parse_cache_control_header( - value: t.Optional[str], on_update: _t_cc_update, cls: t.Type[_TAnyCC] + value: str | None, on_update: _t_cc_update, cls: type[_TAnyCC] ) -> _TAnyCC: ... def parse_cache_control_header( - value: t.Optional[str], + value: str | None, on_update: _t_cc_update = None, - cls: t.Optional[t.Type[_TAnyCC]] = None, + cls: type[_TAnyCC] | None = None, ) -> _TAnyCC: """Parse a cache control header. The RFC differs between response and request cache control, this method does not. It's your responsibility @@ -555,24 +667,24 @@ _TAnyCSP = t.TypeVar("_TAnyCSP", bound="ds.ContentSecurityPolicy") _t_csp_update = t.Optional[t.Callable[[_TAnyCSP], None]] -@typing.overload +@t.overload def parse_csp_header( - value: t.Optional[str], on_update: _t_csp_update, cls: None = None -) -> "ds.ContentSecurityPolicy": + value: str | None, on_update: _t_csp_update, cls: None = None +) -> ds.ContentSecurityPolicy: ... -@typing.overload +@t.overload def parse_csp_header( - value: t.Optional[str], on_update: _t_csp_update, cls: t.Type[_TAnyCSP] + value: str | None, on_update: _t_csp_update, cls: type[_TAnyCSP] ) -> _TAnyCSP: ... def parse_csp_header( - value: t.Optional[str], + value: str | None, on_update: _t_csp_update = None, - cls: t.Optional[t.Type[_TAnyCSP]] = None, + cls: type[_TAnyCSP] | None = None, ) -> _TAnyCSP: """Parse a Content Security Policy header. 
@@ -606,9 +718,9 @@ def parse_csp_header( def parse_set_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.HeaderSet"], None]] = None, -) -> "ds.HeaderSet": + value: str | None, + on_update: t.Callable[[ds.HeaderSet], None] | None = None, +) -> ds.HeaderSet: """Parse a set-like header and return a :class:`~werkzeug.datastructures.HeaderSet` object: @@ -638,76 +750,7 @@ def parse_set_header( return ds.HeaderSet(parse_list_header(value), on_update) -def parse_authorization_header( - value: t.Optional[str], -) -> t.Optional["ds.Authorization"]: - """Parse an HTTP basic/digest authorization header transmitted by the web - browser. The return value is either `None` if the header was invalid or - not given, otherwise an :class:`~werkzeug.datastructures.Authorization` - object. - - :param value: the authorization header to parse. - :return: a :class:`~werkzeug.datastructures.Authorization` object or `None`. - """ - if not value: - return None - value = _wsgi_decoding_dance(value) - try: - auth_type, auth_info = value.split(None, 1) - auth_type = auth_type.lower() - except ValueError: - return None - if auth_type == "basic": - try: - username, password = base64.b64decode(auth_info).split(b":", 1) - except Exception: - return None - try: - return ds.Authorization( - "basic", - { - "username": _to_str(username, "utf-8"), - "password": _to_str(password, "utf-8"), - }, - ) - except UnicodeDecodeError: - return None - elif auth_type == "digest": - auth_map = parse_dict_header(auth_info) - for key in "username", "realm", "nonce", "uri", "response": - if key not in auth_map: - return None - if "qop" in auth_map: - if not auth_map.get("nc") or not auth_map.get("cnonce"): - return None - return ds.Authorization("digest", auth_map) - return None - - -def parse_www_authenticate_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.WWWAuthenticate"], None]] = None, -) -> "ds.WWWAuthenticate": - """Parse an HTTP WWW-Authenticate header into a 
- :class:`~werkzeug.datastructures.WWWAuthenticate` object. - - :param value: a WWW-Authenticate header to parse. - :param on_update: an optional callable that is called every time a value - on the :class:`~werkzeug.datastructures.WWWAuthenticate` - object is changed. - :return: a :class:`~werkzeug.datastructures.WWWAuthenticate` object. - """ - if not value: - return ds.WWWAuthenticate(on_update=on_update) - try: - auth_type, auth_info = value.split(None, 1) - auth_type = auth_type.lower() - except (ValueError, AttributeError): - return ds.WWWAuthenticate(value.strip().lower(), on_update=on_update) - return ds.WWWAuthenticate(auth_type, parse_dict_header(auth_info), on_update) - - -def parse_if_range_header(value: t.Optional[str]) -> "ds.IfRange": +def parse_if_range_header(value: str | None) -> ds.IfRange: """Parses an if-range header which can be an etag or a date. Returns a :class:`~werkzeug.datastructures.IfRange` object. @@ -726,8 +769,8 @@ def parse_if_range_header(value: t.Optional[str]) -> "ds.IfRange": def parse_range_header( - value: t.Optional[str], make_inclusive: bool = True -) -> t.Optional["ds.Range"]: + value: str | None, make_inclusive: bool = True +) -> ds.Range | None: """Parses a range header into a :class:`~werkzeug.datastructures.Range` object. If the header is missing or malformed `None` is returned. 
`ranges` is a list of ``(start, stop)`` tuples where the ranges are @@ -751,7 +794,7 @@ def parse_range_header( if last_end < 0: return None try: - begin = int(item) + begin = _plain_int(item) except ValueError: return None end = None @@ -762,7 +805,7 @@ def parse_range_header( end_str = end_str.strip() try: - begin = int(begin_str) + begin = _plain_int(begin_str) except ValueError: return None @@ -770,7 +813,7 @@ def parse_range_header( return None if end_str: try: - end = int(end_str) + 1 + end = _plain_int(end_str) + 1 except ValueError: return None @@ -785,9 +828,9 @@ def parse_range_header( def parse_content_range_header( - value: t.Optional[str], - on_update: t.Optional[t.Callable[["ds.ContentRange"], None]] = None, -) -> t.Optional["ds.ContentRange"]: + value: str | None, + on_update: t.Callable[[ds.ContentRange], None] | None = None, +) -> ds.ContentRange | None: """Parses a range header into a :class:`~werkzeug.datastructures.ContentRange` object or `None` if parsing is not possible. 
@@ -813,19 +856,22 @@ def parse_content_range_header( length = None else: try: - length = int(length_str) + length = _plain_int(length_str) except ValueError: return None if rng == "*": + if not is_byte_range_valid(None, None, length): + return None + return ds.ContentRange(units, None, None, length, on_update=on_update) elif "-" not in rng: return None start_str, stop_str = rng.split("-", 1) try: - start = int(start_str) - stop = int(stop_str) + 1 + start = _plain_int(start_str) + stop = _plain_int(stop_str) + 1 except ValueError: return None @@ -850,8 +896,8 @@ def quote_etag(etag: str, weak: bool = False) -> str: def unquote_etag( - etag: t.Optional[str], -) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: + etag: str | None, +) -> tuple[str, bool] | tuple[None, None]: """Unquote a single etag: >>> unquote_etag('W/"bar"') @@ -874,7 +920,7 @@ def unquote_etag( return etag, weak -def parse_etags(value: t.Optional[str]) -> "ds.ETags": +def parse_etags(value: str | None) -> ds.ETags: """Parse an etag header. :param value: the tag header to parse @@ -912,7 +958,7 @@ def generate_etag(data: bytes) -> str: return sha1(data).hexdigest() -def parse_date(value: t.Optional[str]) -> t.Optional[datetime]: +def parse_date(value: str | None) -> datetime | None: """Parse an :rfc:`2822` date into a timezone-aware :class:`datetime.datetime` object, or ``None`` if parsing fails. @@ -942,7 +988,7 @@ def parse_date(value: t.Optional[str]) -> t.Optional[datetime]: def http_date( - timestamp: t.Optional[t.Union[datetime, date, int, float, struct_time]] = None + timestamp: datetime | date | int | float | struct_time | None = None, ) -> str: """Format a datetime object or timestamp into an :rfc:`2822` date string. 
@@ -973,7 +1019,7 @@ def http_date( return email.utils.formatdate(timestamp, usegmt=True) -def parse_age(value: t.Optional[str] = None) -> t.Optional[timedelta]: +def parse_age(value: str | None = None) -> timedelta | None: """Parses a base-10 integer count of seconds into a timedelta. If parsing fails, the return value is `None`. @@ -995,7 +1041,7 @@ def parse_age(value: t.Optional[str] = None) -> t.Optional[timedelta]: return None -def dump_age(age: t.Optional[t.Union[timedelta, int]] = None) -> t.Optional[str]: +def dump_age(age: timedelta | int | None = None) -> str | None: """Formats the duration as a base-10 integer. :param age: should be an integer number of seconds, @@ -1016,10 +1062,10 @@ def dump_age(age: t.Optional[t.Union[timedelta, int]] = None) -> t.Optional[str] def is_resource_modified( - environ: "WSGIEnvironment", - etag: t.Optional[str] = None, - data: t.Optional[bytes] = None, - last_modified: t.Optional[t.Union[datetime, str]] = None, + environ: WSGIEnvironment, + etag: str | None = None, + data: bytes | None = None, + last_modified: datetime | str | None = None, ignore_if_range: bool = True, ) -> bool: """Convenience method for conditional requests. @@ -1054,7 +1100,7 @@ def is_resource_modified( def remove_entity_headers( - headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]], + headers: ds.Headers | list[tuple[str, str]], allowed: t.Iterable[str] = ("expires", "content-location"), ) -> None: """Remove all entity headers from a list or :class:`Headers` object. This @@ -1077,9 +1123,7 @@ def remove_entity_headers( ] -def remove_hop_by_hop_headers( - headers: t.Union["ds.Headers", t.List[t.Tuple[str, str]]] -) -> None: +def remove_hop_by_hop_headers(headers: ds.Headers | list[tuple[str, str]]) -> None: """Remove all HTTP/1.1 "Hop-by-Hop" headers from a list or :class:`Headers` object. This operation works in-place. 
@@ -1115,11 +1159,9 @@ def is_hop_by_hop_header(header: str) -> bool: def parse_cookie( - header: t.Union["WSGIEnvironment", str, bytes, None], - charset: str = "utf-8", - errors: str = "replace", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": + header: WSGIEnvironment | str | None, + cls: type[ds.MultiDict] | None = None, +) -> ds.MultiDict[str, str]: """Parse a cookie from a string or WSGI environ. The same key can be provided multiple times, the values are stored @@ -1129,44 +1171,51 @@ def parse_cookie( :param header: The cookie header as a string, or a WSGI environ dict with a ``HTTP_COOKIE`` key. - :param charset: The charset for the cookie values. - :param errors: The error behavior for the charset decoding. :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. - .. versionchanged:: 1.0.0 - Returns a :class:`MultiDict` instead of a - ``TypeConversionDict``. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. + + .. versionchanged:: 1.0 + Returns a :class:`MultiDict` instead of a ``TypeConversionDict``. .. versionchanged:: 0.5 - Returns a :class:`TypeConversionDict` instead of a regular dict. - The ``cls`` parameter was added. + Returns a :class:`TypeConversionDict` instead of a regular dict. The ``cls`` + parameter was added. 
""" if isinstance(header, dict): - cookie = header.get("HTTP_COOKIE", "") - elif header is None: - cookie = "" + cookie = header.get("HTTP_COOKIE") else: cookie = header - return _sansio_http.parse_cookie( - cookie=cookie, charset=charset, errors=errors, cls=cls - ) + if cookie: + cookie = cookie.encode("latin1").decode() + + return _sansio_http.parse_cookie(cookie=cookie, cls=cls) + + +_cookie_no_quote_re = re.compile(r"[\w!#$%&'()*+\-./:<=>?@\[\]^`{|}~]*", re.A) +_cookie_slash_re = re.compile(rb"[\x00-\x19\",;\\\x7f-\xff]", re.A) +_cookie_slash_map = {b'"': b'\\"', b"\\": b"\\\\"} +_cookie_slash_map.update( + (v.to_bytes(1, "big"), b"\\%03o" % v) + for v in [*range(0x20), *b",;", *range(0x7F, 256)] +) def dump_cookie( key: str, - value: t.Union[bytes, str] = "", - max_age: t.Optional[t.Union[timedelta, int]] = None, - expires: t.Optional[t.Union[str, datetime, int, float]] = None, - path: t.Optional[str] = "/", - domain: t.Optional[str] = None, + value: str = "", + max_age: timedelta | int | None = None, + expires: str | datetime | int | float | None = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - charset: str = "utf-8", sync_expires: bool = True, max_size: int = 4093, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> str: """Create a Set-Cookie header without the ``Set-Cookie`` prefix. @@ -1187,7 +1236,7 @@ def dump_cookie( :param path: limits the cookie to a given path, per default it will span the whole domain. :param domain: Use this if you want to set a cross-domain cookie. For - example, ``domain=".example.com"`` will set a cookie + example, ``domain="example.com"`` will set a cookie that is readable by the domain ``www.example.com``, ``foo.example.com`` etc. Otherwise, a cookie will only be readable by the domain that set it. @@ -1206,18 +1255,33 @@ def dump_cookie( .. _`cookie`: http://browsercookielimits.squawky.net/ + .. 
versionchanged:: 3.0 + Passing bytes, and the ``charset`` parameter, were removed. + + .. versionchanged:: 2.3.3 + The ``path`` parameter is ``/`` by default. + + .. versionchanged:: 2.3.1 + The value allows more characters without quoting. + + .. versionchanged:: 2.3 + ``localhost`` and other names without a dot are allowed for the domain. A + leading dot is ignored. + + .. versionchanged:: 2.3 + The ``path`` parameter is ``None`` by default. + .. versionchanged:: 1.0.0 The string ``'None'`` is accepted for ``samesite``. """ - key = _to_bytes(key, charset) - value = _to_bytes(value, charset) - if path is not None: - from .urls import iri_to_uri + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + # excluding semicolon since it's part of the header syntax + path = quote(path, safe="%!$&'()*+,/:=@") - path = iri_to_uri(path, charset) - - domain = _make_cookie_domain(domain) + if domain: + domain = domain.partition(":")[0].lstrip(".").encode("idna").decode("ascii") if isinstance(max_age, timedelta): max_age = int(max_age.total_seconds()) @@ -1234,54 +1298,51 @@ def dump_cookie( if samesite not in {"Strict", "Lax", "None"}: raise ValueError("SameSite must be 'Strict', 'Lax', or 'None'.") - buf = [key + b"=" + _cookie_quote(value)] + # Quote value if it contains characters not allowed by RFC 6265. Slash-escape with + # three octal digits, which matches http.cookies, although the RFC suggests base64. + if not _cookie_no_quote_re.fullmatch(value): + # Work with bytes here, since a UTF-8 character could be multiple bytes. + value = _cookie_slash_re.sub( + lambda m: _cookie_slash_map[m.group()], value.encode() + ).decode("ascii") + value = f'"{value}"' - # XXX: In theory all of these parameters that are not marked with `None` - # should be quoted. Because stdlib did not quote it before I did not - # want to introduce quoting there now. 
- for k, v, q in ( - (b"Domain", domain, True), - (b"Expires", expires, False), - (b"Max-Age", max_age, False), - (b"Secure", secure, None), - (b"HttpOnly", httponly, None), - (b"Path", path, False), - (b"SameSite", samesite, False), + # Send a non-ASCII key as mojibake. Everything else should already be ASCII. + # TODO Remove encoding dance, it seems like clients accept UTF-8 keys + buf = [f"{key.encode().decode('latin1')}={value}"] + + for k, v in ( + ("Domain", domain), + ("Expires", expires), + ("Max-Age", max_age), + ("Secure", secure), + ("HttpOnly", httponly), + ("Path", path), + ("SameSite", samesite), ): - if q is None: - if v: - buf.append(k) + if v is None or v is False: continue - if v is None: + if v is True: + buf.append(k) continue - tmp = bytearray(k) - if not isinstance(v, (bytes, bytearray)): - v = _to_bytes(str(v), charset) - if q: - v = _cookie_quote(v) - tmp += b"=" + v - buf.append(bytes(tmp)) + buf.append(f"{k}={v}") - # The return value will be an incorrectly encoded latin1 header for - # consistency with the headers object. - rv = b"; ".join(buf) - rv = rv.decode("latin1") + rv = "; ".join(buf) - # Warn if the final value of the cookie is larger than the limit. If the - # cookie is too large, then it may be silently ignored by the browser, - # which can be quite hard to debug. + # Warn if the final value of the cookie is larger than the limit. If the cookie is + # too large, then it may be silently ignored by the browser, which can be quite hard + # to debug. cookie_size = len(rv) if max_size and cookie_size > max_size: value_size = len(value) warnings.warn( - f"The {key.decode(charset)!r} cookie is too large: the value was" - f" {value_size} bytes but the" + f"The '{key}' cookie is too large: the value was {value_size} bytes but the" f" header required {cookie_size - value_size} extra bytes. The final size" f" was {cookie_size} bytes but the limit is {max_size} bytes. 
Browsers may" - f" silently ignore cookies larger than this.", + " silently ignore cookies larger than this.", stacklevel=2, ) @@ -1289,7 +1350,7 @@ def dump_cookie( def is_byte_range_valid( - start: t.Optional[int], stop: t.Optional[int], length: t.Optional[int] + start: int | None, stop: int | None, length: int | None ) -> bool: """Checks if a given byte content range is valid for the given length. diff --git a/src/werkzeug/local.py b/src/werkzeug/local.py index 70e9bf7..fba80e9 100644 --- a/src/werkzeug/local.py +++ b/src/werkzeug/local.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import copy import math import operator @@ -18,7 +20,7 @@ T = t.TypeVar("T") F = t.TypeVar("F", bound=t.Callable[..., t.Any]) -def release_local(local: t.Union["Local", "LocalStack"]) -> None: +def release_local(local: Local | LocalStack) -> None: """Release the data for the current context in a :class:`Local` or :class:`LocalStack` without using a :class:`LocalManager`. @@ -49,9 +51,7 @@ class Local: __slots__ = ("__storage",) - def __init__( - self, context_var: t.Optional[ContextVar[t.Dict[str, t.Any]]] = None - ) -> None: + def __init__(self, context_var: ContextVar[dict[str, t.Any]] | None = None) -> None: if context_var is None: # A ContextVar not created at global scope interferes with # Python's garbage collection. However, a local only makes @@ -61,12 +61,10 @@ class Local: object.__setattr__(self, "_Local__storage", context_var) - def __iter__(self) -> t.Iterator[t.Tuple[str, t.Any]]: + def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: return iter(self.__storage.get({}).items()) - def __call__( - self, name: str, *, unbound_message: t.Optional[str] = None - ) -> "LocalProxy": + def __call__(self, name: str, *, unbound_message: str | None = None) -> LocalProxy: """Create a :class:`LocalProxy` that access an attribute on this local namespace. 
@@ -124,7 +122,7 @@ class LocalStack(t.Generic[T]): __slots__ = ("_storage",) - def __init__(self, context_var: t.Optional[ContextVar[t.List[T]]] = None) -> None: + def __init__(self, context_var: ContextVar[list[T]] | None = None) -> None: if context_var is None: # A ContextVar not created at global scope interferes with # Python's garbage collection. However, a local only makes @@ -137,14 +135,14 @@ class LocalStack(t.Generic[T]): def __release_local__(self) -> None: self._storage.set([]) - def push(self, obj: T) -> t.List[T]: + def push(self, obj: T) -> list[T]: """Add a new item to the top of the stack.""" stack = self._storage.get([]).copy() stack.append(obj) self._storage.set(stack) return stack - def pop(self) -> t.Optional[T]: + def pop(self) -> T | None: """Remove the top item from the stack and return it. If the stack is empty, return ``None``. """ @@ -158,7 +156,7 @@ class LocalStack(t.Generic[T]): return rv @property - def top(self) -> t.Optional[T]: + def top(self) -> T | None: """The topmost item on the stack. If the stack is empty, `None` is returned. """ @@ -170,8 +168,8 @@ class LocalStack(t.Generic[T]): return stack[-1] def __call__( - self, name: t.Optional[str] = None, *, unbound_message: t.Optional[str] = None - ) -> "LocalProxy": + self, name: str | None = None, *, unbound_message: str | None = None + ) -> LocalProxy: """Create a :class:`LocalProxy` that accesses the top of this local stack. @@ -192,9 +190,8 @@ class LocalManager: :param locals: A local or list of locals to manage. - .. versionchanged:: 2.0 - ``ident_func`` is deprecated and will be removed in Werkzeug - 2.1. + .. versionchanged:: 2.1 + The ``ident_func`` was removed. .. versionchanged:: 0.7 The ``ident_func`` parameter was added. 
@@ -208,9 +205,7 @@ class LocalManager: def __init__( self, - locals: t.Optional[ - t.Union[Local, LocalStack, t.Iterable[t.Union[Local, LocalStack]]] - ] = None, + locals: None | (Local | LocalStack | t.Iterable[Local | LocalStack]) = None, ) -> None: if locals is None: self.locals = [] @@ -226,19 +221,19 @@ class LocalManager: for local in self.locals: release_local(local) - def make_middleware(self, app: "WSGIApplication") -> "WSGIApplication": + def make_middleware(self, app: WSGIApplication) -> WSGIApplication: """Wrap a WSGI application so that local data is released automatically after the response has been sent for a request. """ def application( - environ: "WSGIEnvironment", start_response: "StartResponse" + environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: return ClosingIterator(app(environ, start_response), self.cleanup) return application - def middleware(self, func: "WSGIApplication") -> "WSGIApplication": + def middleware(self, func: WSGIApplication) -> WSGIApplication: """Like :meth:`make_middleware` but used as a decorator on the WSGI application function. @@ -274,24 +269,24 @@ class _ProxyLookup: def __init__( self, - f: t.Optional[t.Callable] = None, - fallback: t.Optional[t.Callable] = None, - class_value: t.Optional[t.Any] = None, + f: t.Callable | None = None, + fallback: t.Callable | None = None, + class_value: t.Any | None = None, is_attr: bool = False, ) -> None: - bind_f: t.Optional[t.Callable[["LocalProxy", t.Any], t.Callable]] + bind_f: t.Callable[[LocalProxy, t.Any], t.Callable] | None if hasattr(f, "__get__"): # A Python function, can be turned into a bound method. - def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: + def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: return f.__get__(obj, type(obj)) # type: ignore elif f is not None: # A C function, use partial to bind the first argument. 
- def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: - return partial(f, obj) # type: ignore + def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: + return partial(f, obj) else: # Use getattr, which will produce a bound method. @@ -302,10 +297,10 @@ class _ProxyLookup: self.class_value = class_value self.is_attr = is_attr - def __set_name__(self, owner: "LocalProxy", name: str) -> None: + def __set_name__(self, owner: LocalProxy, name: str) -> None: self.name = name - def __get__(self, instance: "LocalProxy", owner: t.Optional[type] = None) -> t.Any: + def __get__(self, instance: LocalProxy, owner: type | None = None) -> t.Any: if instance is None: if self.class_value is not None: return self.class_value @@ -313,7 +308,7 @@ class _ProxyLookup: return self try: - obj = instance._get_current_object() # type: ignore[misc] + obj = instance._get_current_object() except RuntimeError: if self.fallback is None: raise @@ -335,7 +330,7 @@ class _ProxyLookup: def __repr__(self) -> str: return f"proxy {self.name}" - def __call__(self, instance: "LocalProxy", *args: t.Any, **kwargs: t.Any) -> t.Any: + def __call__(self, instance: LocalProxy, *args: t.Any, **kwargs: t.Any) -> t.Any: """Support calling unbound methods from the class. For example, this happens with ``copy.copy``, which does ``type(x).__copy__(x)``. 
``type(x)`` can't be proxied, so it @@ -352,12 +347,12 @@ class _ProxyIOp(_ProxyLookup): __slots__ = () def __init__( - self, f: t.Optional[t.Callable] = None, fallback: t.Optional[t.Callable] = None + self, f: t.Callable | None = None, fallback: t.Callable | None = None ) -> None: super().__init__(f, fallback) - def bind_f(instance: "LocalProxy", obj: t.Any) -> t.Callable: - def i_op(self: t.Any, other: t.Any) -> "LocalProxy": + def bind_f(instance: LocalProxy, obj: t.Any) -> t.Callable: + def i_op(self: t.Any, other: t.Any) -> LocalProxy: f(self, other) # type: ignore return instance @@ -471,10 +466,10 @@ class LocalProxy(t.Generic[T]): def __init__( self, - local: t.Union[ContextVar[T], Local, LocalStack[T], t.Callable[[], T]], - name: t.Optional[str] = None, + local: ContextVar[T] | Local | LocalStack[T] | t.Callable[[], T], + name: str | None = None, *, - unbound_message: t.Optional[str] = None, + unbound_message: str | None = None, ) -> None: if name is None: get_name = _identity @@ -497,7 +492,7 @@ class LocalProxy(t.Generic[T]): elif isinstance(local, LocalStack): def _get_current_object() -> T: - obj = local.top # type: ignore[union-attr] + obj = local.top if obj is None: raise RuntimeError(unbound_message) @@ -508,7 +503,7 @@ class LocalProxy(t.Generic[T]): def _get_current_object() -> T: try: - obj = local.get() # type: ignore[union-attr] + obj = local.get() except LookupError: raise RuntimeError(unbound_message) from None @@ -517,7 +512,7 @@ class LocalProxy(t.Generic[T]): elif callable(local): def _get_current_object() -> T: - return get_name(local()) # type: ignore + return get_name(local()) else: raise TypeError(f"Don't know how to proxy '{type(local)}'.") diff --git a/src/werkzeug/middleware/__init__.py b/src/werkzeug/middleware/__init__.py index 6ddcf7f..e69de29 100644 --- a/src/werkzeug/middleware/__init__.py +++ b/src/werkzeug/middleware/__init__.py @@ -1,22 +0,0 @@ -""" -Middleware -========== - -A WSGI middleware is a WSGI application that 
wraps another application -in order to observe or change its behavior. Werkzeug provides some -middleware for common use cases. - -.. toctree:: - :maxdepth: 1 - - proxy_fix - shared_data - dispatcher - http_proxy - lint - profiler - -The :doc:`interactive debugger ` is also a middleware that can -be applied manually, although it is typically used automatically with -the :doc:`development server `. -""" diff --git a/src/werkzeug/middleware/dispatcher.py b/src/werkzeug/middleware/dispatcher.py index ace1c75..559fea5 100644 --- a/src/werkzeug/middleware/dispatcher.py +++ b/src/werkzeug/middleware/dispatcher.py @@ -30,6 +30,8 @@ and the static files would be served directly by the HTTP server. :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + import typing as t if t.TYPE_CHECKING: @@ -50,14 +52,14 @@ class DispatcherMiddleware: def __init__( self, - app: "WSGIApplication", - mounts: t.Optional[t.Dict[str, "WSGIApplication"]] = None, + app: WSGIApplication, + mounts: dict[str, WSGIApplication] | None = None, ) -> None: self.app = app self.mounts = mounts or {} def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: script = environ.get("PATH_INFO", "") path_info = "" diff --git a/src/werkzeug/middleware/http_proxy.py b/src/werkzeug/middleware/http_proxy.py index 1cde458..59ba9b3 100644 --- a/src/werkzeug/middleware/http_proxy.py +++ b/src/werkzeug/middleware/http_proxy.py @@ -7,13 +7,15 @@ Basic HTTP Proxy :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + import typing as t from http import client +from urllib.parse import quote +from urllib.parse import urlsplit from ..datastructures import EnvironHeaders from ..http import is_hop_by_hop_header -from ..urls import url_parse -from ..urls import url_quote from ..wsgi import get_input_stream if t.TYPE_CHECKING: @@ -78,12 +80,12 @@ 
class ProxyMiddleware: def __init__( self, - app: "WSGIApplication", - targets: t.Mapping[str, t.Dict[str, t.Any]], + app: WSGIApplication, + targets: t.Mapping[str, dict[str, t.Any]], chunk_size: int = 2 << 13, timeout: int = 10, ) -> None: - def _set_defaults(opts: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + def _set_defaults(opts: dict[str, t.Any]) -> dict[str, t.Any]: opts.setdefault("remove_prefix", False) opts.setdefault("host", "") opts.setdefault("headers", {}) @@ -98,13 +100,14 @@ class ProxyMiddleware: self.timeout = timeout def proxy_to( - self, opts: t.Dict[str, t.Any], path: str, prefix: str - ) -> "WSGIApplication": - target = url_parse(opts["target"]) - host = t.cast(str, target.ascii_host) + self, opts: dict[str, t.Any], path: str, prefix: str + ) -> WSGIApplication: + target = urlsplit(opts["target"]) + # socket can handle unicode host, but header must be ascii + host = target.hostname.encode("idna").decode("ascii") def application( - environ: "WSGIEnvironment", start_response: "StartResponse" + environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: headers = list(EnvironHeaders(environ).items()) headers[:] = [ @@ -157,7 +160,9 @@ class ProxyMiddleware: ) con.connect() - remote_url = url_quote(remote_path) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + remote_url = quote(remote_path, safe="!$&'()*+,/:;=@%") querystring = environ["QUERY_STRING"] if querystring: @@ -217,7 +222,7 @@ class ProxyMiddleware: return application def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: path = environ["PATH_INFO"] app = self.app diff --git a/src/werkzeug/middleware/lint.py b/src/werkzeug/middleware/lint.py index 6b54630..4629599 100644 --- a/src/werkzeug/middleware/lint.py +++ b/src/werkzeug/middleware/lint.py @@ -12,6 +12,8 @@ common HTTP errors 
such as non-empty responses for 304 status codes. :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + import typing as t from types import TracebackType from urllib.parse import urlparse @@ -117,7 +119,7 @@ class ErrorStream: class GuardedWrite: - def __init__(self, write: t.Callable[[bytes], object], chunks: t.List[int]) -> None: + def __init__(self, write: t.Callable[[bytes], object], chunks: list[int]) -> None: self._write = write self._chunks = chunks @@ -131,8 +133,8 @@ class GuardedIterator: def __init__( self, iterator: t.Iterable[bytes], - headers_set: t.Tuple[int, Headers], - chunks: t.List[int], + headers_set: tuple[int, Headers], + chunks: list[int], ) -> None: self._iterator = iterator self._next = iter(iterator).__next__ @@ -140,7 +142,7 @@ class GuardedIterator: self.headers_set = headers_set self.chunks = chunks - def __iter__(self) -> "GuardedIterator": + def __iter__(self) -> GuardedIterator: return self def __next__(self) -> bytes: @@ -164,7 +166,7 @@ class GuardedIterator: self.closed = True if hasattr(self._iterator, "close"): - self._iterator.close() # type: ignore + self._iterator.close() if self.headers_set: status_code, headers = self.headers_set @@ -230,10 +232,10 @@ class LintMiddleware: app = LintMiddleware(app) """ - def __init__(self, app: "WSGIApplication") -> None: + def __init__(self, app: WSGIApplication) -> None: self.app = app - def check_environ(self, environ: "WSGIEnvironment") -> None: + def check_environ(self, environ: WSGIEnvironment) -> None: if type(environ) is not dict: warn( "WSGI environment is not a standard Python dict.", @@ -280,11 +282,9 @@ class LintMiddleware: def check_start_response( self, status: str, - headers: t.List[t.Tuple[str, str]], - exc_info: t.Optional[ - t.Tuple[t.Type[BaseException], BaseException, TracebackType] - ], - ) -> t.Tuple[int, Headers]: + headers: list[tuple[str, str]], + exc_info: None | (tuple[type[BaseException], BaseException, TracebackType]), + ) 
-> tuple[int, Headers]: check_type("status", status, str) status_code_str = status.split(None, 1)[0] @@ -359,9 +359,9 @@ class LintMiddleware: ) def check_iterator(self, app_iter: t.Iterable[bytes]) -> None: - if isinstance(app_iter, bytes): + if isinstance(app_iter, str): warn( - "The application returned a bytestring. The response will send one" + "The application returned a string. The response will send one" " character at a time to the client, which will kill performance." " Return a list or iterable instead.", WSGIWarning, @@ -377,8 +377,8 @@ class LintMiddleware: "A WSGI app does not take keyword arguments.", WSGIWarning, stacklevel=2 ) - environ: "WSGIEnvironment" = args[0] - start_response: "StartResponse" = args[1] + environ: WSGIEnvironment = args[0] + start_response: StartResponse = args[1] self.check_environ(environ) environ["wsgi.input"] = InputStream(environ["wsgi.input"]) @@ -388,8 +388,8 @@ class LintMiddleware: # iterate to the end and we can check the content length. environ["wsgi.file_wrapper"] = FileWrapper - headers_set: t.List[t.Any] = [] - chunks: t.List[int] = [] + headers_set: list[t.Any] = [] + chunks: list[int] = [] def checking_start_response( *args: t.Any, **kwargs: t.Any @@ -405,10 +405,10 @@ class LintMiddleware: warn("'start_response' does not take keyword arguments.", WSGIWarning) status: str = args[0] - headers: t.List[t.Tuple[str, str]] = args[1] - exc_info: t.Optional[ - t.Tuple[t.Type[BaseException], BaseException, TracebackType] - ] = (args[2] if len(args) == 3 else None) + headers: list[tuple[str, str]] = args[1] + exc_info: None | ( + tuple[type[BaseException], BaseException, TracebackType] + ) = (args[2] if len(args) == 3 else None) headers_set[:] = self.check_start_response(status, headers, exc_info) return GuardedWrite(start_response(status, headers, exc_info), chunks) diff --git a/src/werkzeug/middleware/profiler.py b/src/werkzeug/middleware/profiler.py index 200dae0..1120c83 100644 --- 
a/src/werkzeug/middleware/profiler.py +++ b/src/werkzeug/middleware/profiler.py @@ -11,6 +11,8 @@ that may be slowing down your application. :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + import os.path import sys import time @@ -42,11 +44,16 @@ class ProfilerMiddleware: - ``{method}`` - The request method; GET, POST, etc. - ``{path}`` - The request path or 'root' should one not exist. - - ``{elapsed}`` - The elapsed time of the request. + - ``{elapsed}`` - The elapsed time of the request in milliseconds. - ``{time}`` - The time of the request. - If it is a callable, it will be called with the WSGI ``environ`` - dict and should return a filename. + If it is a callable, it will be called with the WSGI ``environ`` and + be expected to return a filename string. The ``environ`` dictionary + will also have the ``"werkzeug.profiler"`` key populated with a + dictionary containing the following fields (more may be added in the + future): + - ``{elapsed}`` - The elapsed time of the request in milliseconds. + - ``{time}`` - The time of the request. :param app: The WSGI application to wrap. :param stream: Write stats to this stream. Disable with ``None``. @@ -63,6 +70,10 @@ class ProfilerMiddleware: from werkzeug.middleware.profiler import ProfilerMiddleware app = ProfilerMiddleware(app) + .. versionchanged:: 3.0 + Added the ``"werkzeug.profiler"`` key to the ``filename_format(environ)`` + parameter with the ``elapsed`` and ``time`` fields. + .. versionchanged:: 0.15 Stats are written even if ``profile_dir`` is given, and can be disable by passing ``stream=None``. 
@@ -76,11 +87,11 @@ class ProfilerMiddleware: def __init__( self, - app: "WSGIApplication", - stream: t.IO[str] = sys.stdout, + app: WSGIApplication, + stream: t.IO[str] | None = sys.stdout, sort_by: t.Iterable[str] = ("time", "calls"), - restrictions: t.Iterable[t.Union[str, int, float]] = (), - profile_dir: t.Optional[str] = None, + restrictions: t.Iterable[str | int | float] = (), + profile_dir: str | None = None, filename_format: str = "{method}.{path}.{elapsed:.0f}ms.{time:.0f}.prof", ) -> None: self._app = app @@ -91,9 +102,9 @@ class ProfilerMiddleware: self._filename_format = filename_format def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: - response_body: t.List[bytes] = [] + response_body: list[bytes] = [] def catching_start_response(status, headers, exc_info=None): # type: ignore start_response(status, headers, exc_info) @@ -106,7 +117,7 @@ class ProfilerMiddleware: response_body.extend(app_iter) if hasattr(app_iter, "close"): - app_iter.close() # type: ignore + app_iter.close() profile = Profile() start = time.time() @@ -116,6 +127,10 @@ class ProfilerMiddleware: if self._profile_dir is not None: if callable(self._filename_format): + environ["werkzeug.profiler"] = { + "elapsed": elapsed * 1000.0, + "time": time.time(), + } filename = self._filename_format(environ) else: filename = self._filename_format.format( diff --git a/src/werkzeug/middleware/proxy_fix.py b/src/werkzeug/middleware/proxy_fix.py index 4cef7cc..8dfbb36 100644 --- a/src/werkzeug/middleware/proxy_fix.py +++ b/src/werkzeug/middleware/proxy_fix.py @@ -21,6 +21,8 @@ setting each header so the middleware knows what to trust. :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + import typing as t from ..http import parse_list_header @@ -64,23 +66,16 @@ class ProxyFix: app = ProxyFix(app, x_for=1, x_host=1) .. 
versionchanged:: 1.0 - Deprecated code has been removed: - - * The ``num_proxies`` argument and attribute. - * The ``get_remote_addr`` method. - * The environ keys ``orig_remote_addr``, - ``orig_wsgi_url_scheme``, and ``orig_http_host``. + The ``num_proxies`` argument and attribute; the ``get_remote_addr`` method; and + the environ keys ``orig_remote_addr``, ``orig_wsgi_url_scheme``, and + ``orig_http_host`` were removed. .. versionchanged:: 0.15 - All headers support multiple values. The ``num_proxies`` - argument is deprecated. Each header is configured with a - separate number of trusted proxies. + All headers support multiple values. Each header is configured with a separate + number of trusted proxies. .. versionchanged:: 0.15 - Original WSGI environ values are stored in the - ``werkzeug.proxy_fix.orig`` dict. ``orig_remote_addr``, - ``orig_wsgi_url_scheme``, and ``orig_http_host`` are deprecated - and will be removed in 1.0. + Original WSGI environ values are stored in the ``werkzeug.proxy_fix.orig`` dict. .. versionchanged:: 0.15 Support ``X-Forwarded-Port`` and ``X-Forwarded-Prefix``. @@ -92,7 +87,7 @@ class ProxyFix: def __init__( self, - app: "WSGIApplication", + app: WSGIApplication, x_for: int = 1, x_proto: int = 1, x_host: int = 0, @@ -106,7 +101,7 @@ class ProxyFix: self.x_port = x_port self.x_prefix = x_prefix - def _get_real_value(self, trusted: int, value: t.Optional[str]) -> t.Optional[str]: + def _get_real_value(self, trusted: int, value: str | None) -> str | None: """Get the real value from a list header based on the configured number of trusted proxies. @@ -128,7 +123,7 @@ class ProxyFix: return None def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Modify the WSGI environ based on the various ``Forwarded`` headers before calling the wrapped application. 
Store the diff --git a/src/werkzeug/middleware/shared_data.py b/src/werkzeug/middleware/shared_data.py index 2ec396c..e3ec7ca 100644 --- a/src/werkzeug/middleware/shared_data.py +++ b/src/werkzeug/middleware/shared_data.py @@ -8,9 +8,11 @@ Serve Shared Static Files :copyright: 2007 Pallets :license: BSD-3-Clause """ +from __future__ import annotations + +import importlib.util import mimetypes import os -import pkgutil import posixpath import typing as t from datetime import datetime @@ -99,18 +101,18 @@ class SharedDataMiddleware: def __init__( self, - app: "WSGIApplication", - exports: t.Union[ - t.Dict[str, t.Union[str, t.Tuple[str, str]]], - t.Iterable[t.Tuple[str, t.Union[str, t.Tuple[str, str]]]], - ], + app: WSGIApplication, + exports: ( + dict[str, str | tuple[str, str]] + | t.Iterable[tuple[str, str | tuple[str, str]]] + ), disallow: None = None, cache: bool = True, cache_timeout: int = 60 * 60 * 12, fallback_mimetype: str = "application/octet-stream", ) -> None: self.app = app - self.exports: t.List[t.Tuple[str, _TLoader]] = [] + self.exports: list[tuple[str, _TLoader]] = [] self.cache = cache self.cache_timeout = cache_timeout @@ -156,12 +158,12 @@ class SharedDataMiddleware: def get_package_loader(self, package: str, package_path: str) -> _TLoader: load_time = datetime.now(timezone.utc) - provider = pkgutil.get_loader(package) - reader = provider.get_resource_reader(package) # type: ignore + spec = importlib.util.find_spec(package) + reader = spec.loader.get_resource_reader(package) # type: ignore[union-attr] def loader( - path: t.Optional[str], - ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: + path: str | None, + ) -> tuple[str | None, _TOpener | None]: if path is None: return None, None @@ -198,8 +200,8 @@ class SharedDataMiddleware: def get_directory_loader(self, directory: str) -> _TLoader: def loader( - path: t.Optional[str], - ) -> t.Tuple[t.Optional[str], t.Optional[_TOpener]]: + path: str | None, + ) -> tuple[str | None, _TOpener | None]: 
if path is not None: path = safe_join(directory, path) @@ -222,7 +224,7 @@ class SharedDataMiddleware: return f"wzsdm-{timestamp}-{file_size}-{checksum}" def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: path = get_path_info(environ) file_loader = None diff --git a/src/werkzeug/routing/converters.py b/src/werkzeug/routing/converters.py index bbad29d..ce01dd1 100644 --- a/src/werkzeug/routing/converters.py +++ b/src/werkzeug/routing/converters.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import re import typing as t import uuid - -from ..urls import _fast_url_quote +from urllib.parse import quote if t.TYPE_CHECKING: from .map import Map @@ -15,22 +16,33 @@ class ValidationError(ValueError): class BaseConverter: - """Base class for all converters.""" + """Base class for all converters. + + .. versionchanged:: 2.3 + ``part_isolating`` defaults to ``False`` if ``regex`` contains a ``/``. + """ regex = "[^/]+" weight = 100 part_isolating = True - def __init__(self, map: "Map", *args: t.Any, **kwargs: t.Any) -> None: + def __init_subclass__(cls, **kwargs: t.Any) -> None: + super().__init_subclass__(**kwargs) + + # If the converter isn't inheriting its regex, disable part_isolating by default + # if the regex contains a / character. 
+ if "regex" in cls.__dict__ and "part_isolating" not in cls.__dict__: + cls.part_isolating = "/" not in cls.regex + + def __init__(self, map: Map, *args: t.Any, **kwargs: t.Any) -> None: self.map = map def to_python(self, value: str) -> t.Any: return value def to_url(self, value: t.Any) -> str: - if isinstance(value, (bytes, bytearray)): - return _fast_url_quote(value) - return _fast_url_quote(str(value).encode(self.map.charset)) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + return quote(str(value), safe="!$&'()*+,/:;=@") class UnicodeConverter(BaseConverter): @@ -51,14 +63,12 @@ class UnicodeConverter(BaseConverter): :param length: the exact length of the string. """ - part_isolating = True - def __init__( self, - map: "Map", + map: Map, minlength: int = 1, - maxlength: t.Optional[int] = None, - length: t.Optional[int] = None, + maxlength: int | None = None, + length: int | None = None, ) -> None: super().__init__(map) if length is not None: @@ -86,9 +96,7 @@ class AnyConverter(BaseConverter): Value is validated when building a URL. """ - part_isolating = True - - def __init__(self, map: "Map", *items: str) -> None: + def __init__(self, map: Map, *items: str) -> None: super().__init__(map) self.items = set(items) self.regex = f"(?:{'|'.join([re.escape(x) for x in items])})" @@ -111,9 +119,9 @@ class PathConverter(BaseConverter): :param map: the :class:`Map`. """ + part_isolating = False regex = "[^/].*?" 
weight = 200 - part_isolating = False class NumberConverter(BaseConverter): @@ -124,14 +132,13 @@ class NumberConverter(BaseConverter): weight = 50 num_convert: t.Callable = int - part_isolating = True def __init__( self, - map: "Map", + map: Map, fixed_digits: int = 0, - min: t.Optional[int] = None, - max: t.Optional[int] = None, + min: int | None = None, + max: int | None = None, signed: bool = False, ) -> None: if signed: @@ -186,7 +193,6 @@ class IntegerConverter(NumberConverter): """ regex = r"\d+" - part_isolating = True class FloatConverter(NumberConverter): @@ -210,13 +216,12 @@ class FloatConverter(NumberConverter): regex = r"\d+\.\d+" num_convert = float - part_isolating = True def __init__( self, - map: "Map", - min: t.Optional[float] = None, - max: t.Optional[float] = None, + map: Map, + min: float | None = None, + max: float | None = None, signed: bool = False, ) -> None: super().__init__(map, min=min, max=max, signed=signed) # type: ignore @@ -236,7 +241,6 @@ class UUIDConverter(BaseConverter): r"[A-Fa-f0-9]{8}-[A-Fa-f0-9]{4}-" r"[A-Fa-f0-9]{4}-[A-Fa-f0-9]{4}-[A-Fa-f0-9]{12}" ) - part_isolating = True def to_python(self, value: str) -> uuid.UUID: return uuid.UUID(value) @@ -246,7 +250,7 @@ class UUIDConverter(BaseConverter): #: the default converter mapping for the map. 
-DEFAULT_CONVERTERS: t.Mapping[str, t.Type[BaseConverter]] = { +DEFAULT_CONVERTERS: t.Mapping[str, type[BaseConverter]] = { "default": UnicodeConverter, "string": UnicodeConverter, "any": AnyConverter, diff --git a/src/werkzeug/routing/exceptions.py b/src/werkzeug/routing/exceptions.py index 7cbe6e9..9d0a528 100644 --- a/src/werkzeug/routing/exceptions.py +++ b/src/werkzeug/routing/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import difflib import typing as t @@ -9,7 +11,7 @@ from ..utils import redirect if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIEnvironment from .map import MapAdapter - from .rules import Rule # noqa: F401 + from .rules import Rule from ..wrappers.request import Request from ..wrappers.response import Response @@ -37,9 +39,9 @@ class RequestRedirect(HTTPException, RoutingException): def get_response( self, - environ: t.Optional[t.Union["WSGIEnvironment", "Request"]] = None, - scope: t.Optional[dict] = None, - ) -> "Response": + environ: WSGIEnvironment | Request | None = None, + scope: dict | None = None, + ) -> Response: return redirect(self.new_url, self.code) @@ -71,8 +73,8 @@ class BuildError(RoutingException, LookupError): self, endpoint: str, values: t.Mapping[str, t.Any], - method: t.Optional[str], - adapter: t.Optional["MapAdapter"] = None, + method: str | None, + adapter: MapAdapter | None = None, ) -> None: super().__init__(endpoint, values, method) self.endpoint = endpoint @@ -81,11 +83,11 @@ class BuildError(RoutingException, LookupError): self.adapter = adapter @cached_property - def suggested(self) -> t.Optional["Rule"]: + def suggested(self) -> Rule | None: return self.closest_rule(self.adapter) - def closest_rule(self, adapter: t.Optional["MapAdapter"]) -> t.Optional["Rule"]: - def _score_rule(rule: "Rule") -> float: + def closest_rule(self, adapter: MapAdapter | None) -> Rule | None: + def _score_rule(rule: Rule) -> float: return sum( [ 0.98 @@ -141,6 +143,6 @@ class WebsocketMismatch(BadRequest): 
class NoMatch(Exception): __slots__ = ("have_match_for", "websocket_mismatch") - def __init__(self, have_match_for: t.Set[str], websocket_mismatch: bool) -> None: + def __init__(self, have_match_for: set[str], websocket_mismatch: bool) -> None: self.have_match_for = have_match_for self.websocket_mismatch = websocket_mismatch diff --git a/src/werkzeug/routing/map.py b/src/werkzeug/routing/map.py index daf94b6..76bbe2f 100644 --- a/src/werkzeug/routing/map.py +++ b/src/werkzeug/routing/map.py @@ -1,12 +1,14 @@ -import posixpath +from __future__ import annotations + import typing as t import warnings from pprint import pformat from threading import Lock +from urllib.parse import quote +from urllib.parse import urljoin +from urllib.parse import urlunsplit -from .._internal import _encode_idna from .._internal import _get_environ -from .._internal import _to_str from .._internal import _wsgi_decoding_dance from ..datastructures import ImmutableDict from ..datastructures import MultiDict @@ -14,9 +16,7 @@ from ..exceptions import BadHost from ..exceptions import HTTPException from ..exceptions import MethodNotAllowed from ..exceptions import NotFound -from ..urls import url_encode -from ..urls import url_join -from ..urls import url_quote +from ..urls import _urlencode from ..wsgi import get_host from .converters import DEFAULT_CONVERTERS from .exceptions import BuildError @@ -30,7 +30,6 @@ from .rules import _simple_rule_re from .rules import Rule if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment from .converters import BaseConverter @@ -48,7 +47,6 @@ class Map: :param rules: sequence of url rules for this map. :param default_subdomain: The default subdomain for rules without a subdomain defined. - :param charset: charset of the url. defaults to ``"utf-8"`` :param strict_slashes: If a rule ends with a slash but the matched URL does not, redirect to the URL with a trailing slash. 
:param merge_slashes: Merge consecutive slashes when matching or @@ -63,24 +61,25 @@ class Map: :param sort_parameters: If set to `True` the url parameters are sorted. See `url_encode` for more details. :param sort_key: The sort key function for `url_encode`. - :param encoding_errors: the error method to use for decoding :param host_matching: if set to `True` it enables the host matching feature and disables the subdomain one. If enabled the `host` parameter to rules is used instead of the `subdomain` one. - .. versionchanged:: 1.0 - If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules - will match. + .. versionchanged:: 3.0 + The ``charset`` and ``encoding_errors`` parameters were removed. .. versionchanged:: 1.0 - Added ``merge_slashes``. + If ``url_scheme`` is ``ws`` or ``wss``, only WebSocket rules will match. + + .. versionchanged:: 1.0 + The ``merge_slashes`` parameter was added. .. versionchanged:: 0.7 - Added ``encoding_errors`` and ``host_matching``. + The ``encoding_errors`` and ``host_matching`` parameters were added. .. versionchanged:: 0.5 - Added ``sort_parameters`` and ``sort_key``. + The ``sort_parameters`` and ``sort_key`` paramters were added. """ #: A dict of default converters to be used. 
@@ -93,26 +92,22 @@ class Map: def __init__( self, - rules: t.Optional[t.Iterable["RuleFactory"]] = None, + rules: t.Iterable[RuleFactory] | None = None, default_subdomain: str = "", - charset: str = "utf-8", strict_slashes: bool = True, merge_slashes: bool = True, redirect_defaults: bool = True, - converters: t.Optional[t.Mapping[str, t.Type["BaseConverter"]]] = None, + converters: t.Mapping[str, type[BaseConverter]] | None = None, sort_parameters: bool = False, - sort_key: t.Optional[t.Callable[[t.Any], t.Any]] = None, - encoding_errors: str = "replace", + sort_key: t.Callable[[t.Any], t.Any] | None = None, host_matching: bool = False, ) -> None: self._matcher = StateMachineMatcher(merge_slashes) - self._rules_by_endpoint: t.Dict[str, t.List[Rule]] = {} + self._rules_by_endpoint: dict[str, list[Rule]] = {} self._remap = True self._remap_lock = self.lock_class() self.default_subdomain = default_subdomain - self.charset = charset - self.encoding_errors = encoding_errors self.strict_slashes = strict_slashes self.merge_slashes = merge_slashes self.redirect_defaults = redirect_defaults @@ -149,10 +144,10 @@ class Map: return False @property - def _rules(self) -> t.List[Rule]: + def _rules(self) -> list[Rule]: return [rule for rules in self._rules_by_endpoint.values() for rule in rules] - def iter_rules(self, endpoint: t.Optional[str] = None) -> t.Iterator[Rule]: + def iter_rules(self, endpoint: str | None = None) -> t.Iterator[Rule]: """Iterate over all rules or the rules of an endpoint. :param endpoint: if provided only the rules for that endpoint @@ -164,7 +159,7 @@ class Map: return iter(self._rules_by_endpoint[endpoint]) return iter(self._rules) - def add(self, rulefactory: "RuleFactory") -> None: + def add(self, rulefactory: RuleFactory) -> None: """Add a new rule or factory to the map and bind it. Requires that the rule is not bound to another map. 
@@ -180,13 +175,13 @@ class Map: def bind( self, server_name: str, - script_name: t.Optional[str] = None, - subdomain: t.Optional[str] = None, + script_name: str | None = None, + subdomain: str | None = None, url_scheme: str = "http", default_method: str = "GET", - path_info: t.Optional[str] = None, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - ) -> "MapAdapter": + path_info: str | None = None, + query_args: t.Mapping[str, t.Any] | str | None = None, + ) -> MapAdapter: """Return a new :class:`MapAdapter` with the details specified to the call. Note that `script_name` will default to ``'/'`` if not further specified or `None`. The `server_name` at least is a requirement @@ -227,14 +222,17 @@ class Map: if path_info is None: path_info = "/" + # Port isn't part of IDNA, and might push a name over the 63 octet limit. + server_name, port_sep, port = server_name.partition(":") + try: - server_name = _encode_idna(server_name) # type: ignore + server_name = server_name.encode("idna").decode("ascii") except UnicodeError as e: raise BadHost() from e return MapAdapter( self, - server_name, + f"{server_name}{port_sep}{port}", script_name, subdomain, url_scheme, @@ -245,10 +243,10 @@ class Map: def bind_to_environ( self, - environ: t.Union["WSGIEnvironment", "Request"], - server_name: t.Optional[str] = None, - subdomain: t.Optional[str] = None, - ) -> "MapAdapter": + environ: WSGIEnvironment | Request, + server_name: str | None = None, + subdomain: str | None = None, + ) -> MapAdapter: """Like :meth:`bind` but you can pass it an WSGI environment and it will fetch the information from that dictionary. 
Note that because of limitations in the protocol there is no way to get the current @@ -332,10 +330,10 @@ class Map: else: subdomain = ".".join(filter(None, cur_server_name[:offset])) - def _get_wsgi_string(name: str) -> t.Optional[str]: + def _get_wsgi_string(name: str) -> str | None: val = env.get(name) if val is not None: - return _wsgi_decoding_dance(val, self.charset) + return _wsgi_decoding_dance(val) return None script_name = _get_wsgi_string("SCRIPT_NAME") @@ -384,32 +382,33 @@ class MapAdapter: map: Map, server_name: str, script_name: str, - subdomain: t.Optional[str], + subdomain: str | None, url_scheme: str, path_info: str, default_method: str, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, + query_args: t.Mapping[str, t.Any] | str | None = None, ): self.map = map - self.server_name = _to_str(server_name) - script_name = _to_str(script_name) + self.server_name = server_name + if not script_name.endswith("/"): script_name += "/" + self.script_name = script_name - self.subdomain = _to_str(subdomain) - self.url_scheme = _to_str(url_scheme) - self.path_info = _to_str(path_info) - self.default_method = _to_str(default_method) + self.subdomain = subdomain + self.url_scheme = url_scheme + self.path_info = path_info + self.default_method = default_method self.query_args = query_args self.websocket = self.url_scheme in {"ws", "wss"} def dispatch( self, - view_func: t.Callable[[str, t.Mapping[str, t.Any]], "WSGIApplication"], - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, + view_func: t.Callable[[str, t.Mapping[str, t.Any]], WSGIApplication], + path_info: str | None = None, + method: str | None = None, catch_http_exceptions: bool = False, - ) -> "WSGIApplication": + ) -> WSGIApplication: """Does the complete dispatching process. `view_func` is called with the endpoint and a dict with the values for the view. 
It should look up the view function, call it, and return a response object @@ -466,33 +465,33 @@ class MapAdapter: @t.overload def match( # type: ignore self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, - return_rule: "te.Literal[False]" = False, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[str, t.Mapping[str, t.Any]]: + path_info: str | None = None, + method: str | None = None, + return_rule: t.Literal[False] = False, + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[str, t.Mapping[str, t.Any]]: ... @t.overload def match( self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, - return_rule: "te.Literal[True]" = True, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[Rule, t.Mapping[str, t.Any]]: + path_info: str | None = None, + method: str | None = None, + return_rule: t.Literal[True] = True, + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[Rule, t.Mapping[str, t.Any]]: ... def match( self, - path_info: t.Optional[str] = None, - method: t.Optional[str] = None, + path_info: str | None = None, + method: str | None = None, return_rule: bool = False, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - websocket: t.Optional[bool] = None, - ) -> t.Tuple[t.Union[str, Rule], t.Mapping[str, t.Any]]: + query_args: t.Mapping[str, t.Any] | str | None = None, + websocket: bool | None = None, + ) -> tuple[str | Rule, t.Mapping[str, t.Any]]: """The usage is simple: you just pass the match method the current path info as well as the method (which defaults to `GET`). 
The following things can then happen: @@ -583,8 +582,6 @@ class MapAdapter: self.map.update() if path_info is None: path_info = self.path_info - else: - path_info = _to_str(path_info, self.map.charset) if query_args is None: query_args = self.query_args or {} method = (method or self.default_method).upper() @@ -592,17 +589,20 @@ class MapAdapter: if websocket is None: websocket = self.websocket - domain_part = self.server_name if self.map.host_matching else self.subdomain + domain_part = self.server_name + + if not self.map.host_matching and self.subdomain is not None: + domain_part = self.subdomain + path_part = f"/{path_info.lstrip('/')}" if path_info else "" try: result = self.map._matcher.match(domain_part, path_part, method, websocket) except RequestPath as e: + # safe = https://url.spec.whatwg.org/#url-path-segment-string + new_path = quote(e.path_info, safe="!$&'()*+,/:;=@") raise RequestRedirect( - self.make_redirect_url( - url_quote(e.path_info, self.map.charset, safe="/:|+"), - query_args, - ) + self.make_redirect_url(new_path, query_args) ) from None except RequestAliasRedirect as e: raise RequestRedirect( @@ -647,7 +647,7 @@ class MapAdapter: netloc = self.server_name raise RequestRedirect( - url_join( + urljoin( f"{self.url_scheme or 'http'}://{netloc}{self.script_name}", redirect_url, ) @@ -658,9 +658,7 @@ class MapAdapter: else: return rule.endpoint, rv - def test( - self, path_info: t.Optional[str] = None, method: t.Optional[str] = None - ) -> bool: + def test(self, path_info: str | None = None, method: str | None = None) -> bool: """Test if a rule would match. Works like `match` but returns `True` if the URL matches, or `False` if it does not exist. @@ -677,7 +675,7 @@ class MapAdapter: return False return True - def allowed_methods(self, path_info: t.Optional[str] = None) -> t.Iterable[str]: + def allowed_methods(self, path_info: str | None = None) -> t.Iterable[str]: """Returns the valid methods that match for a given path. .. 
versionadded:: 0.7 @@ -690,7 +688,7 @@ class MapAdapter: pass return [] - def get_host(self, domain_part: t.Optional[str]) -> str: + def get_host(self, domain_part: str | None) -> str: """Figures out the full host name for the given domain part. The domain part is a subdomain in case host matching is disabled or a full host name. @@ -698,12 +696,13 @@ class MapAdapter: if self.map.host_matching: if domain_part is None: return self.server_name - return _to_str(domain_part, "ascii") - subdomain = domain_part - if subdomain is None: + + return domain_part + + if domain_part is None: subdomain = self.subdomain else: - subdomain = _to_str(subdomain, "ascii") + subdomain = domain_part if subdomain: return f"{subdomain}.{self.server_name}" @@ -715,8 +714,8 @@ class MapAdapter: rule: Rule, method: str, values: t.MutableMapping[str, t.Any], - query_args: t.Union[t.Mapping[str, t.Any], str], - ) -> t.Optional[str]: + query_args: t.Mapping[str, t.Any] | str, + ) -> str | None: """A helper that returns the URL to redirect to if it finds one. This is used for default redirecting only. @@ -735,30 +734,33 @@ class MapAdapter: return self.make_redirect_url(path, query_args, domain_part=domain_part) return None - def encode_query_args(self, query_args: t.Union[t.Mapping[str, t.Any], str]) -> str: + def encode_query_args(self, query_args: t.Mapping[str, t.Any] | str) -> str: if not isinstance(query_args, str): - return url_encode(query_args, self.map.charset) + return _urlencode(query_args) return query_args def make_redirect_url( self, path_info: str, - query_args: t.Optional[t.Union[t.Mapping[str, t.Any], str]] = None, - domain_part: t.Optional[str] = None, + query_args: t.Mapping[str, t.Any] | str | None = None, + domain_part: str | None = None, ) -> str: """Creates a redirect URL. 
:internal: """ + if query_args is None: + query_args = self.query_args + if query_args: - suffix = f"?{self.encode_query_args(query_args)}" + query_str = self.encode_query_args(query_args) else: - suffix = "" + query_str = None scheme = self.url_scheme or "http" host = self.get_host(domain_part) - path = posixpath.join(self.script_name.strip("/"), path_info.lstrip("/")) - return f"{scheme}://{host}/{path}{suffix}" + path = "/".join((self.script_name.strip("/"), path_info.lstrip("/"))) + return urlunsplit((scheme, host, path, query_str, None)) def make_alias_redirect_url( self, @@ -766,7 +768,7 @@ class MapAdapter: endpoint: str, values: t.Mapping[str, t.Any], method: str, - query_args: t.Union[t.Mapping[str, t.Any], str], + query_args: t.Mapping[str, t.Any] | str, ) -> str: """Internally called to make an alias redirect URL.""" url = self.build( @@ -781,9 +783,9 @@ class MapAdapter: self, endpoint: str, values: t.Mapping[str, t.Any], - method: t.Optional[str], + method: str | None, append_unknown: bool, - ) -> t.Optional[t.Tuple[str, str, bool]]: + ) -> tuple[str, str, bool] | None: """Helper for :meth:`build`. Returns subdomain and path for the rule that accepts this endpoint, values and method. @@ -821,11 +823,11 @@ class MapAdapter: def build( self, endpoint: str, - values: t.Optional[t.Mapping[str, t.Any]] = None, - method: t.Optional[str] = None, + values: t.Mapping[str, t.Any] | None = None, + method: str | None = None, force_external: bool = False, append_unknown: bool = True, - url_scheme: t.Optional[str] = None, + url_scheme: str | None = None, ) -> str: """Building URLs works pretty much the other way round. 
Instead of `match` you call `build` and pass it the endpoint and a dict of diff --git a/src/werkzeug/routing/matcher.py b/src/werkzeug/routing/matcher.py index d22b05a..0d1210a 100644 --- a/src/werkzeug/routing/matcher.py +++ b/src/werkzeug/routing/matcher.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import typing as t from dataclasses import dataclass @@ -23,9 +25,9 @@ class State: possible *static* and *dynamic* transitions to the next state. """ - dynamic: t.List[t.Tuple[RulePart, "State"]] = field(default_factory=list) - rules: t.List[Rule] = field(default_factory=list) - static: t.Dict[str, "State"] = field(default_factory=dict) + dynamic: list[tuple[RulePart, State]] = field(default_factory=list) + rules: list[Rule] = field(default_factory=list) + static: dict[str, State] = field(default_factory=dict) class StateMachineMatcher: @@ -66,7 +68,7 @@ class StateMachineMatcher: def match( self, domain: str, path: str, method: str, websocket: bool - ) -> t.Tuple[Rule, t.MutableMapping[str, t.Any]]: + ) -> tuple[Rule, t.MutableMapping[str, t.Any]]: # To match to a rule we need to start at the root state and # try to follow the transitions until we find a match, or find # there is no transition to follow. @@ -75,8 +77,8 @@ class StateMachineMatcher: websocket_mismatch = False def _match( - state: State, parts: t.List[str], values: t.List[str] - ) -> t.Optional[t.Tuple[Rule, t.List[str]]]: + state: State, parts: list[str], values: list[str] + ) -> tuple[Rule, list[str]] | None: # This function is meant to be called recursively, and will attempt # to match the head part to the state's transitions. 
nonlocal have_match_for, websocket_mismatch @@ -127,7 +129,22 @@ class StateMachineMatcher: remaining = [] match = re.compile(test_part.content).match(target) if match is not None: - rv = _match(new_state, remaining, values + list(match.groups())) + if test_part.suffixed: + # If a part_isolating=False part has a slash suffix, remove the + # suffix from the match and check for the slash redirect next. + suffix = match.groups()[-1] + if suffix == "/": + remaining = [""] + + converter_groups = sorted( + match.groupdict().items(), key=lambda entry: entry[0] + ) + groups = [ + value + for key, value in converter_groups + if key[:11] == "__werkzeug_" + ] + rv = _match(new_state, remaining, values + groups) if rv is not None: return rv diff --git a/src/werkzeug/routing/rules.py b/src/werkzeug/routing/rules.py index a61717a..5c8184c 100644 --- a/src/werkzeug/routing/rules.py +++ b/src/werkzeug/routing/rules.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import ast import re import typing as t from dataclasses import dataclass from string import Template from types import CodeType +from urllib.parse import quote -from .._internal import _to_bytes -from ..urls import url_encode -from ..urls import url_quote +from ..datastructures import iter_multi_items +from ..urls import _urlencode from .converters import ValidationError if t.TYPE_CHECKING: @@ -17,9 +19,9 @@ if t.TYPE_CHECKING: class Weighting(t.NamedTuple): number_static_weights: int - static_weights: t.List[t.Tuple[int, int]] + static_weights: list[tuple[int, int]] number_argument_weights: int - argument_weights: t.List[int] + argument_weights: list[int] @dataclass @@ -36,22 +38,23 @@ class RulePart: content: str final: bool static: bool + suffixed: bool weight: Weighting _part_re = re.compile( r""" (?: - (?P\/) # a slash + (?P/) # a slash | - (?P[^<\/]+) # static rule data + (?P[^[a-zA-Z_][a-zA-Z0-9_]*) # converter name (?:\((?P.*?)\))? 
# converter arguments - \: # variable delimiter + : # variable delimiter )? (?P[a-zA-Z_][a-zA-Z0-9_]*) # variable name > @@ -92,7 +95,7 @@ def _find(value: str, target: str, pos: int) -> int: return len(value) -def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: +def _pythonize(value: str) -> None | bool | int | float | str: if value in _PYTHON_CONSTANTS: return _PYTHON_CONSTANTS[value] for convert in int, float: @@ -105,7 +108,7 @@ def _pythonize(value: str) -> t.Union[None, bool, int, float, str]: return str(value) -def parse_converter_args(argstr: str) -> t.Tuple[t.Tuple, t.Dict[str, t.Any]]: +def parse_converter_args(argstr: str) -> tuple[t.Tuple, dict[str, t.Any]]: argstr += "," args = [] kwargs = {} @@ -130,7 +133,7 @@ class RuleFactory: be added by subclassing `RuleFactory` and overriding `get_rules`. """ - def get_rules(self, map: "Map") -> t.Iterable["Rule"]: + def get_rules(self, map: Map) -> t.Iterable[Rule]: """Subclasses of `RuleFactory` have to override this method and return an iterable of rules.""" raise NotImplementedError() @@ -159,7 +162,7 @@ class Subdomain(RuleFactory): self.subdomain = subdomain self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -185,7 +188,7 @@ class Submount(RuleFactory): self.path = path.rstrip("/") self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -210,7 +213,7 @@ class EndpointPrefix(RuleFactory): self.prefix = prefix self.rules = rules - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): rule = rule.empty() @@ -237,10 +240,10 @@ class 
RuleTemplate: replace the placeholders in all the string parameters. """ - def __init__(self, rules: t.Iterable["Rule"]) -> None: + def __init__(self, rules: t.Iterable[Rule]) -> None: self.rules = list(rules) - def __call__(self, *args: t.Any, **kwargs: t.Any) -> "RuleTemplateFactory": + def __call__(self, *args: t.Any, **kwargs: t.Any) -> RuleTemplateFactory: return RuleTemplateFactory(self.rules, dict(*args, **kwargs)) @@ -252,12 +255,12 @@ class RuleTemplateFactory(RuleFactory): """ def __init__( - self, rules: t.Iterable[RuleFactory], context: t.Dict[str, t.Any] + self, rules: t.Iterable[RuleFactory], context: dict[str, t.Any] ) -> None: self.rules = rules self.context = context - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: for rulefactory in self.rules: for rule in rulefactory.get_rules(map): new_defaults = subdomain = None @@ -438,25 +441,26 @@ class Rule(RuleFactory): def __init__( self, string: str, - defaults: t.Optional[t.Mapping[str, t.Any]] = None, - subdomain: t.Optional[str] = None, - methods: t.Optional[t.Iterable[str]] = None, + defaults: t.Mapping[str, t.Any] | None = None, + subdomain: str | None = None, + methods: t.Iterable[str] | None = None, build_only: bool = False, - endpoint: t.Optional[str] = None, - strict_slashes: t.Optional[bool] = None, - merge_slashes: t.Optional[bool] = None, - redirect_to: t.Optional[t.Union[str, t.Callable[..., str]]] = None, + endpoint: str | None = None, + strict_slashes: bool | None = None, + merge_slashes: bool | None = None, + redirect_to: str | t.Callable[..., str] | None = None, alias: bool = False, - host: t.Optional[str] = None, + host: str | None = None, websocket: bool = False, ) -> None: if not string.startswith("/"): - raise ValueError("urls must start with a leading slash") + raise ValueError(f"URL rule '{string}' must start with a slash.") + self.rule = string self.is_leaf = not string.endswith("/") self.is_branch = string.endswith("/") 
- self.map: "Map" = None # type: ignore + self.map: Map = None # type: ignore self.strict_slashes = strict_slashes self.merge_slashes = merge_slashes self.subdomain = subdomain @@ -489,11 +493,11 @@ class Rule(RuleFactory): else: self.arguments = set() - self._converters: t.Dict[str, "BaseConverter"] = {} - self._trace: t.List[t.Tuple[bool, str]] = [] - self._parts: t.List[RulePart] = [] + self._converters: dict[str, BaseConverter] = {} + self._trace: list[tuple[bool, str]] = [] + self._parts: list[RulePart] = [] - def empty(self) -> "Rule": + def empty(self) -> Rule: """ Return an unbound copy of this rule. @@ -530,7 +534,7 @@ class Rule(RuleFactory): host=self.host, ) - def get_rules(self, map: "Map") -> t.Iterator["Rule"]: + def get_rules(self, map: Map) -> t.Iterator[Rule]: yield self def refresh(self) -> None: @@ -541,7 +545,7 @@ class Rule(RuleFactory): """ self.bind(self.map, rebind=True) - def bind(self, map: "Map", rebind: bool = False) -> None: + def bind(self, map: Map, rebind: bool = False) -> None: """Bind the url to a map and create a regular expression based on the information from the rule itself and the defaults from the map. @@ -564,7 +568,7 @@ class Rule(RuleFactory): converter_name: str, args: t.Tuple, kwargs: t.Mapping[str, t.Any], - ) -> "BaseConverter": + ) -> BaseConverter: """Looks up the converter for the given parameter. .. 
versionadded:: 0.9 @@ -574,19 +578,20 @@ class Rule(RuleFactory): return self.map.converters[converter_name](self.map, *args, **kwargs) def _encode_query_vars(self, query_vars: t.Mapping[str, t.Any]) -> str: - return url_encode( - query_vars, - charset=self.map.charset, - sort=self.map.sort_parameters, - key=self.map.sort_key, - ) + items: t.Iterable[tuple[str, str]] = iter_multi_items(query_vars) + + if self.map.sort_parameters: + items = sorted(items, key=self.map.sort_key) + + return _urlencode(items) def _parse_rule(self, rule: str) -> t.Iterable[RulePart]: content = "" static = True argument_weights = [] - static_weights: t.List[t.Tuple[int, int]] = [] + static_weights: list[tuple[int, int]] = [] final = False + convertor_number = 0 pos = 0 while pos < len(rule): @@ -613,7 +618,8 @@ class Rule(RuleFactory): self.arguments.add(data["variable"]) if not convobj.part_isolating: final = True - content += f"({convobj.regex})" + content += f"(?P<__werkzeug_{convertor_number}>{convobj.regex})" + convertor_number += 1 argument_weights.append(convobj.weight) self._trace.append((True, data["variable"])) @@ -631,16 +637,27 @@ class Rule(RuleFactory): argument_weights, ) yield RulePart( - content=content, final=final, static=static, weight=weight + content=content, + final=final, + static=static, + suffixed=False, + weight=weight, ) content = "" static = True argument_weights = [] static_weights = [] final = False + convertor_number = 0 pos = match.end() + suffixed = False + if final and content[-1] == "/": + # If a converter is part_isolating=False (matches slashes) and ends with a + # slash, augment the regex to support slash redirects. + suffixed = True + content = content[:-1] + "(? 
None: """Compiles the regular expression and stores it.""" @@ -665,7 +692,11 @@ class Rule(RuleFactory): if domain_rule == "": self._parts = [ RulePart( - content="", final=False, static=True, weight=Weighting(0, [], 0, []) + content="", + final=False, + static=True, + suffixed=False, + weight=Weighting(0, [], 0, []), ) ] else: @@ -676,24 +707,24 @@ class Rule(RuleFactory): rule = re.sub("/{2,}?", "/", self.rule) self._parts.extend(self._parse_rule(rule)) - self._build: t.Callable[..., t.Tuple[str, str]] + self._build: t.Callable[..., tuple[str, str]] self._build = self._compile_builder(False).__get__(self, None) - self._build_unknown: t.Callable[..., t.Tuple[str, str]] + self._build_unknown: t.Callable[..., tuple[str, str]] self._build_unknown = self._compile_builder(True).__get__(self, None) @staticmethod - def _get_func_code(code: CodeType, name: str) -> t.Callable[..., t.Tuple[str, str]]: - globs: t.Dict[str, t.Any] = {} - locs: t.Dict[str, t.Any] = {} + def _get_func_code(code: CodeType, name: str) -> t.Callable[..., tuple[str, str]]: + globs: dict[str, t.Any] = {} + locs: dict[str, t.Any] = {} exec(code, globs, locs) return locs[name] # type: ignore def _compile_builder( self, append_unknown: bool = True - ) -> t.Callable[..., t.Tuple[str, str]]: + ) -> t.Callable[..., tuple[str, str]]: defaults = self.defaults or {} - dom_ops: t.List[t.Tuple[bool, str]] = [] - url_ops: t.List[t.Tuple[bool, str]] = [] + dom_ops: list[tuple[bool, str]] = [] + url_ops: list[tuple[bool, str]] = [] opl = dom_ops for is_dynamic, data in self._trace: @@ -707,9 +738,8 @@ class Rule(RuleFactory): data = self._converters[data].to_url(defaults[data]) opl.append((False, data)) elif not is_dynamic: - opl.append( - (False, url_quote(_to_bytes(data, self.map.charset), safe="/:|+")) - ) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + opl.append((False, quote(data, safe="!$&'()*+,/:;=@"))) else: opl.append((True, data)) @@ -718,17 +748,17 @@ class Rule(RuleFactory): 
ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2 return ret - def _parts(ops: t.List[t.Tuple[bool, str]]) -> t.List[ast.AST]: + def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]: parts = [ - _convert(elem) if is_dynamic else ast.Str(s=elem) + _convert(elem) if is_dynamic else ast.Constant(elem) for is_dynamic, elem in ops ] - parts = parts or [ast.Str("")] + parts = parts or [ast.Constant("")] # constant fold ret = [parts[0]] for p in parts[1:]: - if isinstance(p, ast.Str) and isinstance(ret[-1], ast.Str): - ret[-1] = ast.Str(ret[-1].s + p.s) + if isinstance(p, ast.Constant) and isinstance(ret[-1], ast.Constant): + ret[-1] = ast.Constant(ret[-1].value + p.value) else: ret.append(p) return ret @@ -741,7 +771,7 @@ class Rule(RuleFactory): body = [_IF_KWARGS_URL_ENCODE_AST] url_parts.extend(_URL_ENCODE_AST_NAMES) - def _join(parts: t.List[ast.AST]) -> ast.AST: + def _join(parts: list[ast.AST]) -> ast.AST: if len(parts) == 1: # shortcut return parts[0] return ast.JoinedStr(parts) @@ -764,11 +794,11 @@ class Rule(RuleFactory): func_ast.args.args.append(ast.arg(arg, None)) func_ast.args.kwarg = ast.arg(".kwargs", None) for _ in kargs: - func_ast.args.defaults.append(ast.Str("")) + func_ast.args.defaults.append(ast.Constant("")) func_ast.body = body - # use `ast.parse` instead of `ast.Module` for better portability - # Python 3.8 changes the signature of `ast.Module` + # Use `ast.parse` instead of `ast.Module` for better portability, since the + # signature of `ast.Module` can change. 
module = ast.parse("") module.body = [func_ast] @@ -779,18 +809,18 @@ class Rule(RuleFactory): if "lineno" in node._attributes: node.lineno = 1 if "end_lineno" in node._attributes: - node.end_lineno = node.lineno # type: ignore[attr-defined] + node.end_lineno = node.lineno if "col_offset" in node._attributes: node.col_offset = 0 if "end_col_offset" in node._attributes: - node.end_col_offset = node.col_offset # type: ignore[attr-defined] + node.end_col_offset = node.col_offset code = compile(module, "", "exec") return self._get_func_code(code, func_ast.name) def build( self, values: t.Mapping[str, t.Any], append_unknown: bool = True - ) -> t.Optional[t.Tuple[str, str]]: + ) -> tuple[str, str] | None: """Assembles the relative url for that rule and the subdomain. If building doesn't work for some reasons `None` is returned. @@ -804,7 +834,7 @@ class Rule(RuleFactory): except ValidationError: return None - def provides_defaults_for(self, rule: "Rule") -> bool: + def provides_defaults_for(self, rule: Rule) -> bool: """Check if this rule has defaults for a given rule. :internal: @@ -818,7 +848,7 @@ class Rule(RuleFactory): ) def suitable_for( - self, values: t.Mapping[str, t.Any], method: t.Optional[str] = None + self, values: t.Mapping[str, t.Any], method: str | None = None ) -> bool: """Check if the dict of values has enough data for url generation. @@ -850,7 +880,7 @@ class Rule(RuleFactory): return True - def build_compare_key(self) -> t.Tuple[int, int, int]: + def build_compare_key(self) -> tuple[int, int, int]: """The build compare key for sorting. 
:internal: diff --git a/src/werkzeug/sansio/http.py b/src/werkzeug/sansio/http.py index 8288882..e3cd333 100644 --- a/src/werkzeug/sansio/http.py +++ b/src/werkzeug/sansio/http.py @@ -1,10 +1,10 @@ +from __future__ import annotations + import re import typing as t from datetime import datetime -from .._internal import _cookie_parse_impl from .._internal import _dt_as_utc -from .._internal import _to_str from ..http import generate_etag from ..http import parse_date from ..http import parse_etags @@ -15,14 +15,14 @@ _etag_re = re.compile(r'([Ww]/)?(?:"(.*?)"|(.*?))(?:\s*,\s*|$)') def is_resource_modified( - http_range: t.Optional[str] = None, - http_if_range: t.Optional[str] = None, - http_if_modified_since: t.Optional[str] = None, - http_if_none_match: t.Optional[str] = None, - http_if_match: t.Optional[str] = None, - etag: t.Optional[str] = None, - data: t.Optional[bytes] = None, - last_modified: t.Optional[t.Union[datetime, str]] = None, + http_range: str | None = None, + http_if_range: str | None = None, + http_if_modified_since: str | None = None, + http_if_none_match: str | None = None, + http_if_match: str | None = None, + etag: str | None = None, + data: bytes | None = None, + last_modified: datetime | str | None = None, ignore_if_range: bool = True, ) -> bool: """Convenience method for conditional requests. @@ -63,7 +63,7 @@ def is_resource_modified( if_range = parse_if_range_header(http_if_range) if if_range is not None and if_range.date is not None: - modified_since: t.Optional[datetime] = if_range.date + modified_since: datetime | None = if_range.date else: modified_since = parse_date(http_if_modified_since) @@ -94,12 +94,36 @@ def is_resource_modified( return not unmodified +_cookie_re = re.compile( + r""" + ([^=;]*) + (?:\s*=\s* + ( + "(?:[^\\"]|\\.)*" + | + .*? + ) + )? 
+ \s*;\s* + """, + flags=re.ASCII | re.VERBOSE, +) +_cookie_unslash_re = re.compile(rb"\\([0-3][0-7]{2}|.)") + + +def _cookie_unslash_replace(m: t.Match[bytes]) -> bytes: + v = m.group(1) + + if len(v) == 1: + return v + + return int(v, 8).to_bytes(1, "big") + + def parse_cookie( - cookie: t.Union[bytes, str, None] = "", - charset: str = "utf-8", - errors: str = "replace", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": + cookie: str | None = None, + cls: type[ds.MultiDict] | None = None, +) -> ds.MultiDict[str, str]: """Parse a cookie from a string. The same key can be provided multiple times, the values are stored @@ -108,32 +132,39 @@ def parse_cookie( :meth:`MultiDict.getlist`. :param cookie: The cookie header as a string. - :param charset: The charset for the cookie values. - :param errors: The error behavior for the charset decoding. :param cls: A dict-like class to store the parsed cookies in. Defaults to :class:`MultiDict`. + .. versionchanged:: 3.0 + Passing bytes, and the ``charset`` and ``errors`` parameters, were removed. + .. versionadded:: 2.2 """ - # PEP 3333 sends headers through the environ as latin1 decoded - # strings. Encode strings back to bytes for parsing. - if isinstance(cookie, str): - cookie = cookie.encode("latin1", "replace") - if cls is None: cls = ds.MultiDict - def _parse_pairs() -> t.Iterator[t.Tuple[str, str]]: - for key, val in _cookie_parse_impl(cookie): # type: ignore - key_str = _to_str(key, charset, errors, allow_none_charset=True) + if not cookie: + return cls() - if not key_str: - continue + cookie = f"{cookie};" + out = [] - val_str = _to_str(val, charset, errors, allow_none_charset=True) - yield key_str, val_str + for ck, cv in _cookie_re.findall(cookie): + ck = ck.strip() + cv = cv.strip() - return cls(_parse_pairs()) + if not ck: + continue + + if len(cv) >= 2 and cv[0] == cv[-1] == '"': + # Work with bytes here, since a UTF-8 character could be multiple bytes. 
+ cv = _cookie_unslash_re.sub( + _cookie_unslash_replace, cv[1:-1].encode() + ).decode(errors="replace") + + out.append((ck, cv)) + + return cls(out) # circular dependencies diff --git a/src/werkzeug/sansio/multipart.py b/src/werkzeug/sansio/multipart.py index d8abeb3..fc87353 100644 --- a/src/werkzeug/sansio/multipart.py +++ b/src/werkzeug/sansio/multipart.py @@ -1,14 +1,11 @@ +from __future__ import annotations + import re +import typing as t from dataclasses import dataclass from enum import auto from enum import Enum -from typing import cast -from typing import List -from typing import Optional -from typing import Tuple -from .._internal import _to_bytes -from .._internal import _to_str from ..datastructures import Headers from ..exceptions import RequestEntityTooLarge from ..http import parse_options_header @@ -58,6 +55,7 @@ class State(Enum): PREAMBLE = auto() PART = auto() DATA = auto() + DATA_START = auto() EPILOGUE = auto() COMPLETE = auto() @@ -86,11 +84,14 @@ class MultipartDecoder: def __init__( self, boundary: bytes, - max_form_memory_size: Optional[int] = None, + max_form_memory_size: int | None = None, + *, + max_parts: int | None = None, ) -> None: self.buffer = bytearray() self.complete = False self.max_form_memory_size = max_form_memory_size + self.max_parts = max_parts self.state = State.PREAMBLE self.boundary = boundary @@ -118,20 +119,21 @@ class MultipartDecoder: re.MULTILINE, ) self._search_position = 0 + self._parts_decoded = 0 - def last_newline(self) -> int: + def last_newline(self, data: bytes) -> int: try: - last_nl = self.buffer.rindex(b"\n") + last_nl = data.rindex(b"\n") except ValueError: - last_nl = len(self.buffer) + last_nl = len(data) try: - last_cr = self.buffer.rindex(b"\r") + last_cr = data.rindex(b"\r") except ValueError: - last_cr = len(self.buffer) + last_cr = len(data) return min(last_nl, last_cr) - def receive_data(self, data: Optional[bytes]) -> None: + def receive_data(self, data: bytes | None) -> None: if data is None: 
self.complete = True elif ( @@ -168,7 +170,11 @@ class MultipartDecoder: match = BLANK_LINE_RE.search(self.buffer, self._search_position) if match is not None: headers = self._parse_headers(self.buffer[: match.start()]) - del self.buffer[: match.end()] + # The final header ends with a single CRLF, however a + # blank line indicates the start of the + # body. Therefore the end is after the first CRLF. + headers_end = (match.start() + match.end()) // 2 + del self.buffer[:headers_end] if "content-disposition" not in headers: raise ValueError("Missing Content-Disposition header") @@ -176,7 +182,7 @@ class MultipartDecoder: disposition, extra = parse_options_header( headers["content-disposition"] ) - name = cast(str, extra.get("name")) + name = t.cast(str, extra.get("name")) filename = extra.get("filename") if filename is not None: event = File( @@ -189,36 +195,27 @@ class MultipartDecoder: headers=headers, name=name, ) - self.state = State.DATA + self.state = State.DATA_START self._search_position = 0 + self._parts_decoded += 1 + + if self.max_parts is not None and self._parts_decoded > self.max_parts: + raise RequestEntityTooLarge() else: # Update the search start position to be equal to the # current buffer length (already searched) minus a # safe buffer for part of the search target. self._search_position = max(0, len(self.buffer) - SEARCH_EXTRA_LENGTH) - elif self.state == State.DATA: - if self.buffer.find(b"--" + self.boundary) == -1: - # No complete boundary in the buffer, but there may be - # a partial boundary at the end. As the boundary - # starts with either a nl or cr find the earliest and - # return up to that as data. 
- data_length = del_index = self.last_newline() - more_data = True - else: - match = self.boundary_re.search(self.buffer) - if match is not None: - if match.group(1).startswith(b"--"): - self.state = State.EPILOGUE - else: - self.state = State.PART - data_length = match.start() - del_index = match.end() - else: - data_length = del_index = self.last_newline() - more_data = match is None + elif self.state == State.DATA_START: + data, del_index, more_data = self._parse_data(self.buffer, start=True) + del self.buffer[:del_index] + event = Data(data=data, more_data=more_data) + if more_data: + self.state = State.DATA - data = bytes(self.buffer[:data_length]) + elif self.state == State.DATA: + data, del_index, more_data = self._parse_data(self.buffer, start=False) del self.buffer[:del_index] if data or not more_data: event = Data(data=data, more_data=more_data) @@ -234,16 +231,56 @@ class MultipartDecoder: return event def _parse_headers(self, data: bytes) -> Headers: - headers: List[Tuple[str, str]] = [] + headers: list[tuple[str, str]] = [] # Merge the continued headers into one line data = HEADER_CONTINUATION_RE.sub(b" ", data) # Now there is one header per line for line in data.splitlines(): - if line.strip() != b"": - name, value = _to_str(line).strip().split(":", 1) + line = line.strip() + + if line != b"": + name, _, value = line.decode().partition(":") headers.append((name.strip(), value.strip())) return Headers(headers) + def _parse_data(self, data: bytes, *, start: bool) -> tuple[bytes, int, bool]: + # Body parts must start with CRLF (or CR or LF) + if start: + match = LINE_BREAK_RE.match(data) + data_start = t.cast(t.Match[bytes], match).end() + else: + data_start = 0 + + boundary = b"--" + self.boundary + + if self.buffer.find(boundary) == -1: + # No complete boundary in the buffer, but there may be + # a partial boundary at the end. As the boundary + # starts with either a nl or cr find the earliest and + # return up to that as data. 
+ data_end = del_index = self.last_newline(data[data_start:]) + data_start + # If amount of data after last newline is far from + # possible length of partial boundary, we should + # assume that there is no partial boundary in the buffer + # and return all pending data. + if (len(data) - data_end) > len(b"\n" + boundary): + data_end = del_index = len(data) + more_data = True + else: + match = self.boundary_re.search(data) + if match is not None: + if match.group(1).startswith(b"--"): + self.state = State.EPILOGUE + else: + self.state = State.PART + data_end = match.start() + del_index = match.end() + else: + data_end = del_index = self.last_newline(data[data_start:]) + data_start + more_data = match is None + + return bytes(data[data_start:data_end]), del_index, more_data + class MultipartEncoder: def __init__(self, boundary: bytes) -> None: @@ -259,17 +296,22 @@ class MultipartEncoder: State.PART, State.DATA, }: - self.state = State.DATA data = b"\r\n--" + self.boundary + b"\r\n" - data += b'Content-Disposition: form-data; name="%s"' % _to_bytes(event.name) + data += b'Content-Disposition: form-data; name="%s"' % event.name.encode() if isinstance(event, File): - data += b'; filename="%s"' % _to_bytes(event.filename) + data += b'; filename="%s"' % event.filename.encode() data += b"\r\n" - for name, value in cast(Field, event).headers: + for name, value in t.cast(Field, event).headers: if name.lower() != "content-disposition": - data += _to_bytes(f"{name}: {value}\r\n") - data += b"\r\n" + data += f"{name}: {value}\r\n".encode() + self.state = State.DATA_START return data + elif isinstance(event, Data) and self.state == State.DATA_START: + self.state = State.DATA + if len(event.data) > 0: + return b"\r\n" + event.data + else: + return event.data elif isinstance(event, Data) and self.state == State.DATA: return event.data elif isinstance(event, Epilogue): diff --git a/src/werkzeug/sansio/request.py b/src/werkzeug/sansio/request.py index 8832baa..b59bd5b 100644 --- 
a/src/werkzeug/sansio/request.py +++ b/src/werkzeug/sansio/request.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import typing as t from datetime import datetime +from urllib.parse import parse_qsl -from .._internal import _to_str from ..datastructures import Accept from ..datastructures import Authorization from ..datastructures import CharsetAccept @@ -17,7 +19,6 @@ from ..datastructures import MultiDict from ..datastructures import Range from ..datastructures import RequestCacheControl from ..http import parse_accept_header -from ..http import parse_authorization_header from ..http import parse_cache_control_header from ..http import parse_date from ..http import parse_etags @@ -26,11 +27,11 @@ from ..http import parse_list_header from ..http import parse_options_header from ..http import parse_range_header from ..http import parse_set_header -from ..urls import url_decode from ..user_agent import UserAgent from ..utils import cached_property from ..utils import header_property from .http import parse_cookie +from .utils import get_content_length from .utils import get_current_url from .utils import get_host @@ -57,15 +58,13 @@ class Request: :param headers: The headers received with the request. :param remote_addr: The address of the client sending the request. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` attributes + were removed. + .. versionadded:: 2.0 """ - #: The charset used to decode most data in the request. - charset = "utf-8" - - #: the error handling procedure for errors, defaults to 'replace' - encoding_errors = "replace" - #: the class to use for `args` and `form`. The default is an #: :class:`~werkzeug.datastructures.ImmutableMultiDict` which supports #: multiple values per key. alternatively it makes sense to use an @@ -75,7 +74,7 @@ class Request: #: possible to use mutable structures, but this is not recommended. #: #: .. 
versionadded:: 0.6 - parameter_storage_class: t.Type[MultiDict] = ImmutableMultiDict + parameter_storage_class: type[MultiDict] = ImmutableMultiDict #: The type to be used for dict values from the incoming WSGI #: environment. (For example for :attr:`cookies`.) By default an @@ -85,16 +84,16 @@ class Request: #: Changed to ``ImmutableMultiDict`` to support multiple values. #: #: .. versionadded:: 0.6 - dict_storage_class: t.Type[MultiDict] = ImmutableMultiDict + dict_storage_class: type[MultiDict] = ImmutableMultiDict #: the type to be used for list values from the incoming WSGI environment. #: By default an :class:`~werkzeug.datastructures.ImmutableList` is used #: (for example for :attr:`access_list`). #: #: .. versionadded:: 0.6 - list_storage_class: t.Type[t.List] = ImmutableList + list_storage_class: type[t.List] = ImmutableList - user_agent_class: t.Type[UserAgent] = UserAgent + user_agent_class: type[UserAgent] = UserAgent """The class used and returned by the :attr:`user_agent` property to parse the header. Defaults to :class:`~werkzeug.user_agent.UserAgent`, which does no parsing. An @@ -114,18 +113,18 @@ class Request: #: the application is being run behind one). #: #: .. versionadded:: 0.9 - trusted_hosts: t.Optional[t.List[str]] = None + trusted_hosts: list[str] | None = None def __init__( self, method: str, scheme: str, - server: t.Optional[t.Tuple[str, t.Optional[int]]], + server: tuple[str, int | None] | None, root_path: str, path: str, query_string: bytes, headers: Headers, - remote_addr: t.Optional[str], + remote_addr: str | None, ) -> None: #: The method the request was made with, such as ``GET``. self.method = method.upper() @@ -157,17 +156,8 @@ class Request: return f"<{type(self).__name__} {url!r} [{self.method}]>" - @property - def url_charset(self) -> str: - """The charset that is assumed for URLs. Defaults to the value - of :attr:`charset`. - - .. 
versionadded:: 0.6 - """ - return self.charset - @cached_property - def args(self) -> "MultiDict[str, str]": + def args(self) -> MultiDict[str, str]: """The parsed URL parameters (the part in the URL after the question mark). @@ -176,16 +166,20 @@ class Request: is returned from this function. This can be changed by setting :attr:`parameter_storage_class` to a different type. This might be necessary if the order of the form data is important. + + .. versionchanged:: 2.3 + Invalid bytes remain percent encoded. """ - return url_decode( - self.query_string, - self.url_charset, - errors=self.encoding_errors, - cls=self.parameter_storage_class, + return self.parameter_storage_class( + parse_qsl( + self.query_string.decode(), + keep_blank_values=True, + errors="werkzeug.url_quote", + ) ) @cached_property - def access_route(self) -> t.List[str]: + def access_route(self) -> list[str]: """If a forwarded header exists this is a list of all ip addresses from the client ip to the last proxy server. 
""" @@ -200,7 +194,7 @@ class Request: @cached_property def full_path(self) -> str: """Requested path, including the query string.""" - return f"{self.path}?{_to_str(self.query_string, self.url_charset)}" + return f"{self.path}?{self.query_string.decode()}" @property def is_secure(self) -> bool: @@ -244,15 +238,12 @@ class Request: ) @cached_property - def cookies(self) -> "ImmutableMultiDict[str, str]": + def cookies(self) -> ImmutableMultiDict[str, str]: """A :class:`dict` with the contents of all cookies transmitted with the request.""" wsgi_combined_cookie = ";".join(self.headers.getlist("Cookie")) return parse_cookie( # type: ignore - wsgi_combined_cookie, - self.charset, - self.encoding_errors, - cls=self.dict_storage_class, + wsgi_combined_cookie, cls=self.dict_storage_class ) # Common Descriptors @@ -267,23 +258,16 @@ class Request: ) @cached_property - def content_length(self) -> t.Optional[int]: + def content_length(self) -> int | None: """The Content-Length entity-header field indicates the size of the entity-body in bytes or, in the case of the HEAD method, the size of the entity-body that would have been sent had the request been a GET. """ - if self.headers.get("Transfer-Encoding", "") == "chunked": - return None - - content_length = self.headers.get("Content-Length") - if content_length is not None: - try: - return max(0, int(content_length)) - except (ValueError, TypeError): - pass - - return None + return get_content_length( + http_content_length=self.headers.get("Content-Length"), + http_transfer_encoding=self.headers.get("Transfer-Encoding"), + ) content_encoding = header_property[str]( "Content-Encoding", @@ -358,7 +342,7 @@ class Request: return self._parsed_content_type[0].lower() @property - def mimetype_params(self) -> t.Dict[str, str]: + def mimetype_params(self) -> dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. 
@@ -438,7 +422,7 @@ class Request: return parse_etags(self.headers.get("If-None-Match")) @cached_property - def if_modified_since(self) -> t.Optional[datetime]: + def if_modified_since(self) -> datetime | None: """The parsed `If-Modified-Since` header as a datetime object. .. versionchanged:: 2.0 @@ -447,7 +431,7 @@ class Request: return parse_date(self.headers.get("If-Modified-Since")) @cached_property - def if_unmodified_since(self) -> t.Optional[datetime]: + def if_unmodified_since(self) -> datetime | None: """The parsed `If-Unmodified-Since` header as a datetime object. .. versionchanged:: 2.0 @@ -467,7 +451,7 @@ class Request: return parse_if_range_header(self.headers.get("If-Range")) @cached_property - def range(self) -> t.Optional[Range]: + def range(self) -> Range | None: """The parsed `Range` header. .. versionadded:: 0.7 @@ -485,19 +469,24 @@ class Request: :class:`~werkzeug.user_agent.UserAgent` to provide parsing for the other properties or other extended data. - .. versionchanged:: 2.0 - The built in parser is deprecated and will be removed in - Werkzeug 2.1. A ``UserAgent`` subclass must be set to parse - data from the string. + .. versionchanged:: 2.1 + The built-in parser was removed. Set ``user_agent_class`` to a ``UserAgent`` + subclass to parse data from the string. """ return self.user_agent_class(self.headers.get("User-Agent", "")) # Authorization @cached_property - def authorization(self) -> t.Optional[Authorization]: - """The `Authorization` object in parsed form.""" - return parse_authorization_header(self.headers.get("Authorization")) + def authorization(self) -> Authorization | None: + """The ``Authorization`` header parsed into an :class:`.Authorization` object. + ``None`` if the header is not present. + + .. versionchanged:: 2.3 + :class:`Authorization` is no longer a ``dict``. The ``token`` attribute + was added for auth schemes that use a token instead of parameters. 
+ """ + return Authorization.from_header(self.headers.get("Authorization")) # CORS diff --git a/src/werkzeug/sansio/response.py b/src/werkzeug/sansio/response.py index de0bec2..271974e 100644 --- a/src/werkzeug/sansio/response.py +++ b/src/werkzeug/sansio/response.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import typing as t from datetime import datetime from datetime import timedelta from datetime import timezone from http import HTTPStatus -from .._internal import _to_str from ..datastructures import Headers from ..datastructures import HeaderSet from ..http import dump_cookie @@ -28,14 +29,13 @@ from werkzeug.http import parse_csp_header from werkzeug.http import parse_date from werkzeug.http import parse_options_header from werkzeug.http import parse_set_header -from werkzeug.http import parse_www_authenticate_header from werkzeug.http import quote_etag from werkzeug.http import unquote_etag from werkzeug.utils import header_property -def _set_property(name: str, doc: t.Optional[str] = None) -> property: - def fget(self: "Response") -> HeaderSet: +def _set_property(name: str, doc: str | None = None) -> property: + def fget(self: Response) -> HeaderSet: def on_update(header_set: HeaderSet) -> None: if not header_set and name in self.headers: del self.headers[name] @@ -45,10 +45,8 @@ def _set_property(name: str, doc: t.Optional[str] = None) -> property: return parse_set_header(self.headers.get(name), on_update) def fset( - self: "Response", - value: t.Optional[ - t.Union[str, t.Dict[str, t.Union[str, int]], t.Iterable[str]] - ], + self: Response, + value: None | (str | dict[str, str | int] | t.Iterable[str]), ) -> None: if not value: del self.headers[name] @@ -82,17 +80,17 @@ class Response: :param content_type: The full content type of the response. Overrides building the value from ``mimetype``. + .. versionchanged:: 3.0 + The ``charset`` attribute was removed. + .. versionadded:: 2.0 """ - #: the charset of the response. 
- charset = "utf-8" - #: the default status if none is provided. default_status = 200 #: the default mimetype if none is provided. - default_mimetype: t.Optional[str] = "text/plain" + default_mimetype: str | None = "text/plain" #: Warn if a cookie header exceeds this size. The default, 4093, should be #: safely `supported by most browsers `_. A cookie larger than @@ -109,15 +107,12 @@ class Response: def __init__( self, - status: t.Optional[t.Union[int, str, HTTPStatus]] = None, - headers: t.Optional[ - t.Union[ - t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]], - t.Iterable[t.Tuple[str, t.Union[str, int]]], - ] - ] = None, - mimetype: t.Optional[str] = None, - content_type: t.Optional[str] = None, + status: int | str | HTTPStatus | None = None, + headers: t.Mapping[str, str | t.Iterable[str]] + | t.Iterable[tuple[str, str]] + | None = None, + mimetype: str | None = None, + content_type: str | None = None, ) -> None: if isinstance(headers, Headers): self.headers = headers @@ -130,7 +125,7 @@ class Response: if mimetype is None and "content-type" not in self.headers: mimetype = self.default_mimetype if mimetype is not None: - mimetype = get_content_type(mimetype, self.charset) + mimetype = get_content_type(mimetype, "utf-8") content_type = mimetype if content_type is not None: self.headers["Content-Type"] = content_type @@ -156,30 +151,29 @@ class Response: return self._status @status.setter - def status(self, value: t.Union[str, int, HTTPStatus]) -> None: - if not isinstance(value, (str, bytes, int, HTTPStatus)): - raise TypeError("Invalid status argument") - + def status(self, value: str | int | HTTPStatus) -> None: self._status, self._status_code = self._clean_status(value) - def _clean_status(self, value: t.Union[str, int, HTTPStatus]) -> t.Tuple[str, int]: - if isinstance(value, HTTPStatus): - value = int(value) - status = _to_str(value, self.charset) - split_status = status.split(None, 1) + def _clean_status(self, value: str | int | 
HTTPStatus) -> tuple[str, int]: + if isinstance(value, (int, HTTPStatus)): + status_code = int(value) + else: + value = value.strip() - if len(split_status) == 0: - raise ValueError("Empty status argument") + if not value: + raise ValueError("Empty status argument") - try: - status_code = int(split_status[0]) - except ValueError: - # only message - return f"0 {status}", 0 + code_str, sep, _ = value.partition(" ") - if len(split_status) > 1: - # code and message - return status, status_code + try: + status_code = int(code_str) + except ValueError: + # only message + return f"0 {value}", 0 + + if sep: + # code and message + return value, status_code # only code, look up message try: @@ -193,13 +187,13 @@ class Response: self, key: str, value: str = "", - max_age: t.Optional[t.Union[timedelta, int]] = None, - expires: t.Optional[t.Union[str, datetime, int, float]] = None, - path: t.Optional[str] = "/", - domain: t.Optional[str] = None, + max_age: timedelta | int | None = None, + expires: str | datetime | int | float | None = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> None: """Sets a cookie. @@ -215,7 +209,7 @@ class Response: :param path: limits the cookie to a given path, per default it will span the whole domain. :param domain: if you want to set a cross-domain cookie. For example, - ``domain=".example.com"`` will set a cookie that is + ``domain="example.com"`` will set a cookie that is readable by the domain ``www.example.com``, ``foo.example.com`` etc. Otherwise, a cookie will only be readable by the domain that set it. 
@@ -236,7 +230,6 @@ class Response: domain=domain, secure=secure, httponly=httponly, - charset=self.charset, max_size=self.max_cookie_size, samesite=samesite, ), @@ -245,11 +238,11 @@ class Response: def delete_cookie( self, key: str, - path: str = "/", - domain: t.Optional[str] = None, + path: str | None = "/", + domain: str | None = None, secure: bool = False, httponly: bool = False, - samesite: t.Optional[str] = None, + samesite: str | None = None, ) -> None: """Delete a cookie. Fails silently if key doesn't exist. @@ -290,7 +283,7 @@ class Response: # Common Descriptors @property - def mimetype(self) -> t.Optional[str]: + def mimetype(self) -> str | None: """The mimetype (content type without charset etc.)""" ct = self.headers.get("content-type") @@ -301,10 +294,10 @@ class Response: @mimetype.setter def mimetype(self, value: str) -> None: - self.headers["Content-Type"] = get_content_type(value, self.charset) + self.headers["Content-Type"] = get_content_type(value, "utf-8") @property - def mimetype_params(self) -> t.Dict[str, str]: + def mimetype_params(self) -> dict[str, str]: """The mimetype parameters as dict. For example if the content type is ``text/html; charset=utf-8`` the params would be ``{'charset': 'utf-8'}``. @@ -421,7 +414,7 @@ class Response: ) @property - def retry_after(self) -> t.Optional[datetime]: + def retry_after(self) -> datetime | None: """The Retry-After response-header field can be used with a 503 (Service Unavailable) response to indicate how long the service is expected to be unavailable to the requesting client. 
@@ -443,7 +436,7 @@ class Response: return datetime.now(timezone.utc) + timedelta(seconds=seconds) @retry_after.setter - def retry_after(self, value: t.Optional[t.Union[datetime, int, str]]) -> None: + def retry_after(self, value: datetime | int | str | None) -> None: if value is None: if "retry-after" in self.headers: del self.headers["retry-after"] @@ -501,7 +494,7 @@ class Response: """Set the etag, and override the old one if there was one.""" self.headers["ETag"] = quote_etag(etag, weak) - def get_etag(self) -> t.Union[t.Tuple[str, bool], t.Tuple[None, None]]: + def get_etag(self) -> tuple[str, bool] | tuple[None, None]: """Return a tuple in the form ``(etag, is_weak)``. If there is no ETag the return value is ``(None, None)``. """ @@ -542,7 +535,7 @@ class Response: return rv @content_range.setter - def content_range(self, value: t.Optional[t.Union[ContentRange, str]]) -> None: + def content_range(self, value: ContentRange | str | None) -> None: if not value: del self.headers["content-range"] elif isinstance(value, str): @@ -554,16 +547,70 @@ class Response: @property def www_authenticate(self) -> WWWAuthenticate: - """The ``WWW-Authenticate`` header in a parsed form.""" + """The ``WWW-Authenticate`` header parsed into a :class:`.WWWAuthenticate` + object. Modifying the object will modify the header value. - def on_update(www_auth: WWWAuthenticate) -> None: - if not www_auth and "www-authenticate" in self.headers: - del self.headers["www-authenticate"] - elif www_auth: - self.headers["WWW-Authenticate"] = www_auth.to_header() + This header is not set by default. To set this header, assign an instance of + :class:`.WWWAuthenticate` to this attribute. - header = self.headers.get("www-authenticate") - return parse_www_authenticate_header(header, on_update) + .. 
code-block:: python + + response.www_authenticate = WWWAuthenticate( + "basic", {"realm": "Authentication Required"} + ) + + Multiple values for this header can be sent to give the client multiple options. + Assign a list to set multiple headers. However, modifying the items in the list + will not automatically update the header values, and accessing this attribute + will only ever return the first value. + + To unset this header, assign ``None`` or use ``del``. + + .. versionchanged:: 2.3 + This attribute can be assigned to to set the header. A list can be assigned + to set multiple header values. Use ``del`` to unset the header. + + .. versionchanged:: 2.3 + :class:`WWWAuthenticate` is no longer a ``dict``. The ``token`` attribute + was added for auth challenges that use a token instead of parameters. + """ + value = WWWAuthenticate.from_header(self.headers.get("WWW-Authenticate")) + + if value is None: + value = WWWAuthenticate("basic") + + def on_update(value: WWWAuthenticate) -> None: + self.www_authenticate = value + + value._on_update = on_update + return value + + @www_authenticate.setter + def www_authenticate( + self, value: WWWAuthenticate | list[WWWAuthenticate] | None + ) -> None: + if not value: # None or empty list + del self.www_authenticate + elif isinstance(value, list): + # Clear any existing header by setting the first item. + self.headers.set("WWW-Authenticate", value[0].to_header()) + + for item in value[1:]: + # Add additional header lines for additional items. + self.headers.add("WWW-Authenticate", item.to_header()) + else: + self.headers.set("WWW-Authenticate", value.to_header()) + + def on_update(value: WWWAuthenticate) -> None: + self.www_authenticate = value + + # When setting a single value, allow updating it directly. 
+ value._on_update = on_update + + @www_authenticate.deleter + def www_authenticate(self) -> None: + if "WWW-Authenticate" in self.headers: + del self.headers["WWW-Authenticate"] # CSP @@ -590,7 +637,7 @@ class Response: @content_security_policy.setter def content_security_policy( - self, value: t.Optional[t.Union[ContentSecurityPolicy, str]] + self, value: ContentSecurityPolicy | str | None ) -> None: if not value: del self.headers["content-security-policy"] @@ -625,7 +672,7 @@ class Response: @content_security_policy_report_only.setter def content_security_policy_report_only( - self, value: t.Optional[t.Union[ContentSecurityPolicy, str]] + self, value: ContentSecurityPolicy | str | None ) -> None: if not value: del self.headers["content-security-policy-report-only"] @@ -645,7 +692,7 @@ class Response: return "Access-Control-Allow-Credentials" in self.headers @access_control_allow_credentials.setter - def access_control_allow_credentials(self, value: t.Optional[bool]) -> None: + def access_control_allow_credentials(self, value: bool | None) -> None: if value is True: self.headers["Access-Control-Allow-Credentials"] = "true" else: diff --git a/src/werkzeug/sansio/utils.py b/src/werkzeug/sansio/utils.py index e639dcb..48ec1bf 100644 --- a/src/werkzeug/sansio/utils.py +++ b/src/werkzeug/sansio/utils.py @@ -1,9 +1,11 @@ -import typing as t +from __future__ import annotations -from .._internal import _encode_idna +import typing as t +from urllib.parse import quote + +from .._internal import _plain_int from ..exceptions import SecurityError from ..urls import uri_to_iri -from ..urls import url_quote def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: @@ -18,20 +20,14 @@ def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: if not hostname: return False + try: + hostname = hostname.partition(":")[0].encode("idna").decode("ascii") + except UnicodeEncodeError: + return False + if isinstance(trusted_list, str): trusted_list = 
[trusted_list] - def _normalize(hostname: str) -> bytes: - if ":" in hostname: - hostname = hostname.rsplit(":", 1)[0] - - return _encode_idna(hostname) - - try: - hostname_bytes = _normalize(hostname) - except UnicodeError: - return False - for ref in trusted_list: if ref.startswith("."): ref = ref[1:] @@ -40,14 +36,11 @@ def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: suffix_match = False try: - ref_bytes = _normalize(ref) - except UnicodeError: + ref = ref.partition(":")[0].encode("idna").decode("ascii") + except UnicodeEncodeError: return False - if ref_bytes == hostname_bytes: - return True - - if suffix_match and hostname_bytes.endswith(b"." + ref_bytes): + if ref == hostname or (suffix_match and hostname.endswith(f".{ref}")): return True return False @@ -55,9 +48,9 @@ def host_is_trusted(hostname: str, trusted_list: t.Iterable[str]) -> bool: def get_host( scheme: str, - host_header: t.Optional[str], - server: t.Optional[t.Tuple[str, t.Optional[int]]] = None, - trusted_hosts: t.Optional[t.Iterable[str]] = None, + host_header: str | None, + server: tuple[str, int | None] | None = None, + trusted_hosts: t.Iterable[str] | None = None, ) -> str: """Return the host for the given parameters. @@ -104,9 +97,9 @@ def get_host( def get_current_url( scheme: str, host: str, - root_path: t.Optional[str] = None, - path: t.Optional[str] = None, - query_string: t.Optional[bytes] = None, + root_path: str | None = None, + path: str | None = None, + query_string: bytes | None = None, ) -> str: """Recreate the URL for a request. If an optional part isn't provided, it and subsequent parts are not included in the URL. 
@@ -127,39 +120,40 @@ def get_current_url( url.append("/") return uri_to_iri("".join(url)) - url.append(url_quote(root_path.rstrip("/"))) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + url.append(quote(root_path.rstrip("/"), safe="!$&'()*+,/:;=@%")) url.append("/") if path is None: return uri_to_iri("".join(url)) - url.append(url_quote(path.lstrip("/"))) + url.append(quote(path.lstrip("/"), safe="!$&'()*+,/:;=@%")) if query_string: url.append("?") - url.append(url_quote(query_string, safe=":&%=+$!*'(),")) + url.append(quote(query_string, safe="!$&'()*+,/:;=?@%")) return uri_to_iri("".join(url)) def get_content_length( - http_content_length: t.Union[str, None] = None, - http_transfer_encoding: t.Union[str, None] = "", -) -> t.Optional[int]: - """Returns the content length as an integer or ``None`` if - unavailable or chunked transfer encoding is used. + http_content_length: str | None = None, + http_transfer_encoding: str | None = None, +) -> int | None: + """Return the ``Content-Length`` header value as an int. If the header is not given + or the ``Transfer-Encoding`` header is ``chunked``, ``None`` is returned to indicate + a streaming request. If the value is not an integer, or negative, 0 is returned. :param http_content_length: The Content-Length HTTP header. :param http_transfer_encoding: The Transfer-Encoding HTTP header. .. 
versionadded:: 2.2 """ - if http_transfer_encoding == "chunked": + if http_transfer_encoding == "chunked" or http_content_length is None: return None - if http_content_length is not None: - try: - return max(0, int(http_content_length)) - except (ValueError, TypeError): - pass - return None + try: + return max(0, _plain_int(http_content_length)) + except ValueError: + return 0 diff --git a/src/werkzeug/security.py b/src/werkzeug/security.py index 18d0919..578caf7 100644 --- a/src/werkzeug/security.py +++ b/src/werkzeug/security.py @@ -1,113 +1,130 @@ +from __future__ import annotations + import hashlib import hmac import os import posixpath import secrets -import typing as t - -if t.TYPE_CHECKING: - pass SALT_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -DEFAULT_PBKDF2_ITERATIONS = 260000 +DEFAULT_PBKDF2_ITERATIONS = 600000 -_os_alt_seps: t.List[str] = list( - sep for sep in [os.path.sep, os.path.altsep] if sep is not None and sep != "/" +_os_alt_seps: list[str] = list( + sep for sep in [os.sep, os.path.altsep] if sep is not None and sep != "/" ) def gen_salt(length: int) -> str: """Generate a random string of SALT_CHARS with specified ``length``.""" if length <= 0: - raise ValueError("Salt length must be positive") + raise ValueError("Salt length must be at least 1.") return "".join(secrets.choice(SALT_CHARS) for _ in range(length)) -def _hash_internal(method: str, salt: str, password: str) -> t.Tuple[str, str]: - """Internal password hash helper. Supports plaintext without salt, - unsalted and salted passwords. In case salted passwords are used - hmac is used. 
- """ - if method == "plain": - return password, method - +def _hash_internal(method: str, salt: str, password: str) -> tuple[str, str]: + method, *args = method.split(":") salt = salt.encode("utf-8") password = password.encode("utf-8") - if method.startswith("pbkdf2:"): - if not salt: - raise ValueError("Salt is required for PBKDF2") + if method == "scrypt": + if not args: + n = 2**15 + r = 8 + p = 1 + else: + try: + n, r, p = map(int, args) + except ValueError: + raise ValueError("'scrypt' takes 3 arguments.") from None - args = method[7:].split(":") - - if len(args) not in (1, 2): - raise ValueError("Invalid number of arguments for PBKDF2") - - method = args.pop(0) - iterations = int(args[0] or 0) if args else DEFAULT_PBKDF2_ITERATIONS + maxmem = 132 * n * r * p # ideally 128, but some extra seems needed return ( - hashlib.pbkdf2_hmac(method, password, salt, iterations).hex(), - f"pbkdf2:{method}:{iterations}", + hashlib.scrypt(password, salt=salt, n=n, r=r, p=p, maxmem=maxmem).hex(), + f"scrypt:{n}:{r}:{p}", ) + elif method == "pbkdf2": + len_args = len(args) - if salt: - return hmac.new(salt, password, method).hexdigest(), method + if len_args == 0: + hash_name = "sha256" + iterations = DEFAULT_PBKDF2_ITERATIONS + elif len_args == 1: + hash_name = args[0] + iterations = DEFAULT_PBKDF2_ITERATIONS + elif len_args == 2: + hash_name = args[0] + iterations = int(args[1]) + else: + raise ValueError("'pbkdf2' takes 2 arguments.") - return hashlib.new(method, password).hexdigest(), method + return ( + hashlib.pbkdf2_hmac(hash_name, password, salt, iterations).hex(), + f"pbkdf2:{hash_name}:{iterations}", + ) + else: + raise ValueError(f"Invalid hash method '{method}'.") def generate_password_hash( - password: str, method: str = "pbkdf2:sha256", salt_length: int = 16 + password: str, method: str = "scrypt", salt_length: int = 16 ) -> str: - """Hash a password with the given method and salt with a string of - the given length. 
The format of the string returned includes the method - that was used so that :func:`check_password_hash` can check the hash. + """Securely hash a password for storage. A password can be compared to a stored hash + using :func:`check_password_hash`. - The format for the hashed string looks like this:: + The following methods are supported: - method$salt$hash + - ``scrypt``, the default. The parameters are ``n``, ``r``, and ``p``, the default + is ``scrypt:32768:8:1``. See :func:`hashlib.scrypt`. + - ``pbkdf2``, less secure. The parameters are ``hash_method`` and ``iterations``, + the default is ``pbkdf2:sha256:600000``. See :func:`hashlib.pbkdf2_hmac`. - This method can **not** generate unsalted passwords but it is possible - to set param method='plain' in order to enforce plaintext passwords. - If a salt is used, hmac is used internally to salt the password. + Default parameters may be updated to reflect current guidelines, and methods may be + deprecated and removed if they are no longer considered secure. To migrate old + hashes, you may generate a new hash when checking an old hash, or you may contact + users with a link to reset their password. - If PBKDF2 is wanted it can be enabled by setting the method to - ``pbkdf2:method:iterations`` where iterations is optional:: + :param password: The plaintext password. + :param method: The key derivation function and parameters. + :param salt_length: The number of characters to generate for the salt. - pbkdf2:sha256:80000$salt$hash - pbkdf2:sha256$salt$hash + .. versionchanged:: 2.3 + Scrypt support was added. - :param password: the password to hash. - :param method: the hash method to use (one that hashlib supports). Can - optionally be in the format ``pbkdf2:method:iterations`` - to enable PBKDF2. - :param salt_length: the length of the salt in letters. + .. versionchanged:: 2.3 + The default iterations for pbkdf2 was increased to 600,000. + + .. 
versionchanged:: 2.3 + All plain hashes are deprecated and will not be supported in Werkzeug 3.0. """ - salt = gen_salt(salt_length) if method != "plain" else "" + salt = gen_salt(salt_length) h, actual_method = _hash_internal(method, salt, password) return f"{actual_method}${salt}${h}" def check_password_hash(pwhash: str, password: str) -> bool: - """Check a password against a given salted and hashed password value. - In order to support unsalted legacy passwords this method supports - plain text passwords, md5 and sha1 hashes (both salted and unsalted). + """Securely check that the given stored password hash, previously generated using + :func:`generate_password_hash`, matches the given password. - Returns `True` if the password matched, `False` otherwise. + Methods may be deprecated and removed if they are no longer considered secure. To + migrate old hashes, you may generate a new hash when checking an old hash, or you + may contact users with a link to reset their password. - :param pwhash: a hashed string like returned by - :func:`generate_password_hash`. - :param password: the plaintext password to compare against the hash. + :param pwhash: The hashed password. + :param password: The plaintext password. + + .. versionchanged:: 2.3 + All plain hashes are deprecated and will not be supported in Werkzeug 3.0. """ - if pwhash.count("$") < 2: + try: + method, salt, hashval = pwhash.split("$", 2) + except ValueError: return False - method, salt, hashval = pwhash.split("$", 2) return hmac.compare_digest(_hash_internal(method, salt, password)[0], hashval) -def safe_join(directory: str, *pathnames: str) -> t.Optional[str]: +def safe_join(directory: str, *pathnames: str) -> str | None: """Safely join zero or more untrusted path components to a base directory to avoid escaping the base directory. 
diff --git a/src/werkzeug/serving.py b/src/werkzeug/serving.py index c482469..ff5eb8c 100644 --- a/src/werkzeug/serving.py +++ b/src/werkzeug/serving.py @@ -11,9 +11,12 @@ It provides features like interactive debugging and code reloading. Use from myapp import create_app from werkzeug import run_simple """ +from __future__ import annotations + import errno import io import os +import selectors import socket import socketserver import sys @@ -23,13 +26,13 @@ from datetime import timedelta from datetime import timezone from http.server import BaseHTTPRequestHandler from http.server import HTTPServer +from urllib.parse import unquote +from urllib.parse import urlsplit from ._internal import _log from ._internal import _wsgi_encoding_dance from .exceptions import InternalServerError from .urls import uri_to_iri -from .urls import url_parse -from .urls import url_unquote try: import ssl @@ -70,11 +73,10 @@ except AttributeError: LISTEN_QUEUE = 128 _TSSLContextArg = t.Optional[ - t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], "te.Literal['adhoc']"] + t.Union["ssl.SSLContext", t.Tuple[str, t.Optional[str]], t.Literal["adhoc"]] ] if t.TYPE_CHECKING: - import typing_extensions as te # noqa: F401 from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment from cryptography.hazmat.primitives.asymmetric.rsa import ( @@ -148,16 +150,14 @@ class DechunkedInput(io.RawIOBase): class WSGIRequestHandler(BaseHTTPRequestHandler): """A request handler that implements WSGI dispatching.""" - server: "BaseWSGIServer" + server: BaseWSGIServer @property def server_version(self) -> str: # type: ignore - from . 
import __version__ + return self.server._server_version - return f"Werkzeug/{__version__}" - - def make_environ(self) -> "WSGIEnvironment": - request_url = url_parse(self.path) + def make_environ(self) -> WSGIEnvironment: + request_url = urlsplit(self.path) url_scheme = "http" if self.server.ssl_context is None else "https" if not self.client_address: @@ -173,9 +173,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): else: path_info = request_url.path - path_info = url_unquote(path_info) + path_info = unquote(path_info) - environ: "WSGIEnvironment" = { + environ: WSGIEnvironment = { "wsgi.version": (1, 0), "wsgi.url_scheme": url_scheme, "wsgi.input": self.rfile, @@ -201,6 +201,9 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): } for key, value in self.headers.items(): + if "_" in key: + continue + key = key.upper().replace("-", "_") value = value.replace("\r\n", "") if key not in ("CONTENT_TYPE", "CONTENT_LENGTH"): @@ -221,9 +224,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): try: # binary_form=False gives nicer information, but wouldn't be compatible with # what Nginx or Apache could return. - peer_cert = self.connection.getpeercert( # type: ignore[attr-defined] - binary_form=True - ) + peer_cert = self.connection.getpeercert(binary_form=True) if peer_cert is not None: # Nginx and Apache use PEM format. 
environ["SSL_CLIENT_CERT"] = ssl.DER_cert_to_PEM_cert(peer_cert) @@ -241,10 +242,10 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): self.wfile.write(b"HTTP/1.1 100 Continue\r\n\r\n") self.environ = environ = self.make_environ() - status_set: t.Optional[str] = None - headers_set: t.Optional[t.List[t.Tuple[str, str]]] = None - status_sent: t.Optional[str] = None - headers_sent: t.Optional[t.List[t.Tuple[str, str]]] = None + status_set: str | None = None + headers_set: list[tuple[str, str]] | None = None + status_sent: str | None = None + headers_sent: list[tuple[str, str]] | None = None chunk_response: bool = False def write(data: bytes) -> None: @@ -318,7 +319,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): headers_set = headers return write - def execute(app: "WSGIApplication") -> None: + def execute(app: WSGIApplication) -> None: application_iter = app(environ, start_response) try: for data in application_iter: @@ -328,8 +329,34 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): if chunk_response: self.wfile.write(b"0\r\n\r\n") finally: + # Check for any remaining data in the read socket, and discard it. This + # will read past request.max_content_length, but lets the client see a + # 413 response instead of a connection reset failure. If we supported + # keep-alive connections, this naive approach would break by reading the + # next request line. Since we know that write (above) closes every + # connection we can read everything. + selector = selectors.DefaultSelector() + selector.register(self.connection, selectors.EVENT_READ) + total_size = 0 + total_reads = 0 + + # A timeout of 0 tends to fail because a client needs a small amount of + # time to continue sending its data. + while selector.select(timeout=0.01): + # Only read 10MB into memory at a time. + data = self.rfile.read(10_000_000) + total_size += len(data) + total_reads += 1 + + # Stop reading on no data, >=10GB, or 1000 reads. 
If a client sends + # more than that, they'll get a connection reset failure. + if not data or total_size >= 10_000_000_000 or total_reads > 1000: + break + + selector.close() + if hasattr(application_iter, "close"): - application_iter.close() # type: ignore + application_iter.close() try: execute(self.server.app) @@ -370,7 +397,7 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): raise def connection_dropped( - self, error: BaseException, environ: t.Optional["WSGIEnvironment"] = None + self, error: BaseException, environ: WSGIEnvironment | None = None ) -> None: """Called if the connection was closed by the client. By default nothing happens. @@ -396,9 +423,13 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): def port_integer(self) -> int: return self.client_address[1] - def log_request( - self, code: t.Union[int, str] = "-", size: t.Union[int, str] = "-" - ) -> None: + # Escape control characters. This is defined (but private) in Python 3.12. + _control_char_table = str.maketrans( + {c: rf"\x{c:02x}" for c in [*range(0x20), *range(0x7F, 0xA0)]} + ) + _control_char_table[ord("\\")] = r"\\" + + def log_request(self, code: int | str = "-", size: int | str = "-") -> None: try: path = uri_to_iri(self.path) msg = f"{self.command} {path} {self.request_version}" @@ -406,6 +437,8 @@ class WSGIRequestHandler(BaseHTTPRequestHandler): # path isn't set if the requestline was bad msg = self.requestline + # Escape control characters that may be in the decoded path. 
+ msg = msg.translate(self._control_char_table) code = str(code) if code[0] == "1": # 1xx - Informational @@ -459,8 +492,8 @@ def _ansi_style(value: str, *styles: str) -> str: def generate_adhoc_ssl_pair( - cn: t.Optional[str] = None, -) -> t.Tuple["Certificate", "RSAPrivateKeyWithSerialization"]: + cn: str | None = None, +) -> tuple[Certificate, RSAPrivateKeyWithSerialization]: try: from cryptography import x509 from cryptography.x509.oid import NameOID @@ -505,8 +538,8 @@ def generate_adhoc_ssl_pair( def make_ssl_devcert( - base_path: str, host: t.Optional[str] = None, cn: t.Optional[str] = None -) -> t.Tuple[str, str]: + base_path: str, host: str | None = None, cn: str | None = None +) -> tuple[str, str]: """Creates an SSL key for development. This should be used instead of the ``'adhoc'`` key which generates a new cert on each server start. It accepts a path for where it should store the key and cert and @@ -548,7 +581,7 @@ def make_ssl_devcert( return cert_file, pkey_file -def generate_adhoc_ssl_context() -> "ssl.SSLContext": +def generate_adhoc_ssl_context() -> ssl.SSLContext: """Generates an adhoc SSL context for the development server.""" import tempfile import atexit @@ -579,8 +612,8 @@ def generate_adhoc_ssl_context() -> "ssl.SSLContext": def load_ssl_context( - cert_file: str, pkey_file: t.Optional[str] = None, protocol: t.Optional[int] = None -) -> "ssl.SSLContext": + cert_file: str, pkey_file: str | None = None, protocol: int | None = None +) -> ssl.SSLContext: """Loads SSL context from cert/private key files and optional protocol. Many parameters are directly taken from the API of :py:class:`ssl.SSLContext`. 
@@ -599,7 +632,7 @@ def load_ssl_context( return ctx -def is_ssl_error(error: t.Optional[Exception] = None) -> bool: +def is_ssl_error(error: Exception | None = None) -> bool: """Checks if the given error (or the current one) is an SSL error.""" if error is None: error = t.cast(Exception, sys.exc_info()[1]) @@ -618,11 +651,12 @@ def select_address_family(host: str, port: int) -> socket.AddressFamily: def get_sockaddr( host: str, port: int, family: socket.AddressFamily -) -> t.Union[t.Tuple[str, int], str]: +) -> tuple[str, int] | str: """Return a fully qualified socket address that can be passed to :func:`socket.bind`.""" if family == af_unix: - return host.split("://", 1)[1] + # Absolute path avoids IDNA encoding error when path starts with dot. + return os.path.abspath(host.partition("://")[2]) try: res = socket.getaddrinfo( host, port, family, socket.SOCK_STREAM, socket.IPPROTO_TCP @@ -659,16 +693,17 @@ class BaseWSGIServer(HTTPServer): multithread = False multiprocess = False request_queue_size = LISTEN_QUEUE + allow_reuse_address = True def __init__( self, host: str, port: int, - app: "WSGIApplication", - handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + app: WSGIApplication, + handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> None: if handler is None: handler = WSGIRequestHandler @@ -710,10 +745,36 @@ class BaseWSGIServer(HTTPServer): try: self.server_bind() self.server_activate() + except OSError as e: + # Catch connection issues and show them without the traceback. Show + # extra instructions for address not found, and for macOS. + self.server_close() + print(e.strerror, file=sys.stderr) + + if e.errno == errno.EADDRINUSE: + print( + f"Port {port} is in use by another program. 
Either identify and" + " stop that program, or start the server with a different" + " port.", + file=sys.stderr, + ) + + if sys.platform == "darwin" and port == 5000: + print( + "On macOS, try disabling the 'AirPlay Receiver' service" + " from System Preferences -> General -> AirDrop & Handoff.", + file=sys.stderr, + ) + + sys.exit(1) except BaseException: self.server_close() raise else: + # TCPServer automatically opens a socket even if bind_and_activate is False. + # Close it to silence a ResourceWarning. + self.server_close() + # Use the passed in socket directly. self.socket = socket.fromfd(fd, address_family, socket.SOCK_STREAM) self.server_address = self.socket.getsockname() @@ -729,10 +790,14 @@ class BaseWSGIServer(HTTPServer): ssl_context = generate_adhoc_ssl_context() self.socket = ssl_context.wrap_socket(self.socket, server_side=True) - self.ssl_context: t.Optional["ssl.SSLContext"] = ssl_context + self.ssl_context: ssl.SSLContext | None = ssl_context else: self.ssl_context = None + import importlib.metadata + + self._server_version = f"Werkzeug/{importlib.metadata.version('werkzeug')}" + def log(self, type: str, message: str, *args: t.Any) -> None: _log(type, message, *args) @@ -745,7 +810,7 @@ class BaseWSGIServer(HTTPServer): self.server_close() def handle_error( - self, request: t.Any, client_address: t.Union[t.Tuple[str, int], str] + self, request: t.Any, client_address: tuple[str, int] | str ) -> None: if self.passthrough_errors: raise @@ -811,12 +876,12 @@ class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer): self, host: str, port: int, - app: "WSGIApplication", + app: WSGIApplication, processes: int = 40, - handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> None: if not can_fork: raise ValueError("Your 
platform does not support forking.") @@ -828,13 +893,13 @@ class ForkingWSGIServer(ForkingMixIn, BaseWSGIServer): def make_server( host: str, port: int, - app: "WSGIApplication", + app: WSGIApplication, threaded: bool = False, processes: int = 1, - request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, + request_handler: type[WSGIRequestHandler] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, - fd: t.Optional[int] = None, + ssl_context: _TSSLContextArg | None = None, + fd: int | None = None, ) -> BaseWSGIServer: """Create an appropriate WSGI server instance based on the value of ``threaded`` and ``processes``. @@ -879,77 +944,23 @@ def is_running_from_reloader() -> bool: return os.environ.get("WERKZEUG_RUN_MAIN") == "true" -def prepare_socket(hostname: str, port: int) -> socket.socket: - """Prepare a socket for use by the WSGI server and reloader. - - The socket is marked inheritable so that it can be kept across - reloads instead of breaking connections. - - Catch errors during bind and show simpler error messages. For - "address already in use", show instructions for resolving the issue, - with special instructions for macOS. - - This is called from :func:`run_simple`, but can be used separately - to control server creation with :func:`make_server`. - """ - address_family = select_address_family(hostname, port) - server_address = get_sockaddr(hostname, port, address_family) - s = socket.socket(address_family, socket.SOCK_STREAM) - s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - s.set_inheritable(True) - - # Remove the socket file if it already exists. - if address_family == af_unix: - server_address = t.cast(str, server_address) - - if os.path.exists(server_address): - os.unlink(server_address) - - # Catch connection issues and show them without the traceback. Show - # extra instructions for address not found, and for macOS. 
- try: - s.bind(server_address) - except OSError as e: - print(e.strerror, file=sys.stderr) - - if e.errno == errno.EADDRINUSE: - print( - f"Port {port} is in use by another program. Either" - " identify and stop that program, or start the" - " server with a different port.", - file=sys.stderr, - ) - - if sys.platform == "darwin" and port == 5000: - print( - "On macOS, try disabling the 'AirPlay Receiver'" - " service from System Preferences -> Sharing.", - file=sys.stderr, - ) - - sys.exit(1) - - s.listen(LISTEN_QUEUE) - return s - - def run_simple( hostname: str, port: int, - application: "WSGIApplication", + application: WSGIApplication, use_reloader: bool = False, use_debugger: bool = False, use_evalex: bool = True, - extra_files: t.Optional[t.Iterable[str]] = None, - exclude_patterns: t.Optional[t.Iterable[str]] = None, + extra_files: t.Iterable[str] | None = None, + exclude_patterns: t.Iterable[str] | None = None, reloader_interval: int = 1, reloader_type: str = "auto", threaded: bool = False, processes: int = 1, - request_handler: t.Optional[t.Type[WSGIRequestHandler]] = None, - static_files: t.Optional[t.Dict[str, t.Union[str, t.Tuple[str, str]]]] = None, + request_handler: type[WSGIRequestHandler] | None = None, + static_files: dict[str, str | tuple[str, str]] | None = None, passthrough_errors: bool = False, - ssl_context: t.Optional[_TSSLContextArg] = None, + ssl_context: _TSSLContextArg | None = None, ) -> None: """Start a development server for a WSGI application. Various optional features can be enabled. @@ -997,7 +1008,7 @@ def run_simple( serve static files from using :class:`~werkzeug.middleware.SharedDataMiddleware`. :param passthrough_errors: Don't catch unhandled exceptions at the - server level, let the serve crash instead. If ``use_debugger`` + server level, let the server crash instead. If ``use_debugger`` is enabled, the debugger will still catch such errors. :param ssl_context: Configure TLS to serve over HTTPS. 
Can be an :class:`ssl.SSLContext` object, a ``(cert_file, key_file)`` @@ -1059,12 +1070,7 @@ def run_simple( application = DebuggedApplication(application, evalex=use_evalex) if not is_running_from_reloader(): - s = prepare_socket(hostname, port) - fd = s.fileno() - # Silence a ResourceWarning about an unclosed socket. This object is no longer - # used, the server will create another with fromfd. - s.detach() - os.environ["WERKZEUG_SERVER_FD"] = str(fd) + fd = None else: fd = int(os.environ["WERKZEUG_SERVER_FD"]) @@ -1079,6 +1085,8 @@ def run_simple( ssl_context, fd=fd, ) + srv.socket.set_inheritable(True) + os.environ["WERKZEUG_SERVER_FD"] = str(srv.fileno()) if not is_running_from_reloader(): srv.log_startup() @@ -1087,12 +1095,15 @@ def run_simple( if use_reloader: from ._reloader import run_with_reloader - run_with_reloader( - srv.serve_forever, - extra_files=extra_files, - exclude_patterns=exclude_patterns, - interval=reloader_interval, - reloader_type=reloader_type, - ) + try: + run_with_reloader( + srv.serve_forever, + extra_files=extra_files, + exclude_patterns=exclude_patterns, + interval=reloader_interval, + reloader_type=reloader_type, + ) + finally: + srv.server_close() else: srv.serve_forever() diff --git a/src/werkzeug/test.py b/src/werkzeug/test.py index edb4d4a..7b5899a 100644 --- a/src/werkzeug/test.py +++ b/src/werkzeug/test.py @@ -1,19 +1,21 @@ +from __future__ import annotations + +import dataclasses import mimetypes import sys import typing as t from collections import defaultdict from datetime import datetime -from datetime import timedelta -from http.cookiejar import CookieJar from io import BytesIO from itertools import chain from random import random from tempfile import TemporaryFile from time import time -from urllib.request import Request as _UrllibRequest +from urllib.parse import unquote +from urllib.parse import urlsplit +from urllib.parse import urlunsplit from ._internal import _get_environ -from ._internal import 
_make_encode_wrapper from ._internal import _wsgi_decoding_dance from ._internal import _wsgi_encoding_dance from .datastructures import Authorization @@ -25,6 +27,8 @@ from .datastructures import Headers from .datastructures import MultiDict from .http import dump_cookie from .http import dump_options_header +from .http import parse_cookie +from .http import parse_date from .http import parse_options_header from .sansio.multipart import Data from .sansio.multipart import Epilogue @@ -32,12 +36,8 @@ from .sansio.multipart import Field from .sansio.multipart import File from .sansio.multipart import MultipartEncoder from .sansio.multipart import Preamble +from .urls import _urlencode from .urls import iri_to_uri -from .urls import url_encode -from .urls import url_fix -from .urls import url_parse -from .urls import url_unparse -from .urls import url_unquote from .utils import cached_property from .utils import get_content_type from .wrappers.request import Request @@ -48,18 +48,21 @@ from .wsgi import get_current_url if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment + import typing_extensions as te def stream_encode_multipart( data: t.Mapping[str, t.Any], use_tempfile: bool = True, threshold: int = 1024 * 500, - boundary: t.Optional[str] = None, - charset: str = "utf-8", -) -> t.Tuple[t.IO[bytes], int, str]: + boundary: str | None = None, +) -> tuple[t.IO[bytes], int, str]: """Encode a dict of values (either strings or file descriptors or :class:`FileStorage` objects.) into a multipart encoded string stored in a file descriptor. + + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. 
""" if boundary is None: boundary = f"---------------WerkzeugFormPart_{time()}{random()}" @@ -107,7 +110,8 @@ def stream_encode_multipart( and mimetypes.guess_type(filename)[0] or "application/octet-stream" ) - headers = Headers([("Content-Type", content_type)]) + headers = value.headers + headers.update([("Content-Type", content_type)]) if filename is None: write_binary(encoder.send_event(Field(name=key, headers=headers))) else: @@ -120,6 +124,7 @@ def stream_encode_multipart( chunk = reader(16384) if not chunk: + write_binary(encoder.send_event(Data(data=chunk, more_data=False))) break write_binary(encoder.send_event(Data(data=chunk, more_data=True))) @@ -127,9 +132,7 @@ def stream_encode_multipart( if not isinstance(value, str): value = str(value) write_binary(encoder.send_event(Field(name=key, headers=Headers()))) - write_binary( - encoder.send_event(Data(data=value.encode(charset), more_data=False)) - ) + write_binary(encoder.send_event(Data(data=value.encode(), more_data=False))) write_binary(encoder.send_event(Epilogue(data=b""))) @@ -139,87 +142,21 @@ def stream_encode_multipart( def encode_multipart( - values: t.Mapping[str, t.Any], - boundary: t.Optional[str] = None, - charset: str = "utf-8", -) -> t.Tuple[str, bytes]: + values: t.Mapping[str, t.Any], boundary: str | None = None +) -> tuple[str, bytes]: """Like `stream_encode_multipart` but returns a tuple in the form (``boundary``, ``data``) where data is bytes. + + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. 
""" stream, length, boundary = stream_encode_multipart( - values, use_tempfile=False, boundary=boundary, charset=charset + values, use_tempfile=False, boundary=boundary ) return boundary, stream.read() -class _TestCookieHeaders: - """A headers adapter for cookielib""" - - def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: - self.headers = headers - - def getheaders(self, name: str) -> t.Iterable[str]: - headers = [] - name = name.lower() - for k, v in self.headers: - if k.lower() == name: - headers.append(v) - return headers - - def get_all( - self, name: str, default: t.Optional[t.Iterable[str]] = None - ) -> t.Iterable[str]: - headers = self.getheaders(name) - - if not headers: - return default # type: ignore - - return headers - - -class _TestCookieResponse: - """Something that looks like a httplib.HTTPResponse, but is actually just an - adapter for our test responses to make them available for cookielib. - """ - - def __init__(self, headers: t.Union[Headers, t.List[t.Tuple[str, str]]]) -> None: - self.headers = _TestCookieHeaders(headers) - - def info(self) -> _TestCookieHeaders: - return self.headers - - -class _TestCookieJar(CookieJar): - """A cookielib.CookieJar modified to inject and read cookie headers from - and to wsgi environments, and wsgi application responses. - """ - - def inject_wsgi(self, environ: "WSGIEnvironment") -> None: - """Inject the cookies as client headers into the server's wsgi - environment. - """ - cvals = [f"{c.name}={c.value}" for c in self] - - if cvals: - environ["HTTP_COOKIE"] = "; ".join(cvals) - else: - environ.pop("HTTP_COOKIE", None) - - def extract_wsgi( - self, - environ: "WSGIEnvironment", - headers: t.Union[Headers, t.List[t.Tuple[str, str]]], - ) -> None: - """Extract the server's set-cookie headers as cookies into the - cookie jar. 
- """ - self.extract_cookies( - _TestCookieResponse(headers), # type: ignore - _UrllibRequest(get_current_url(environ)), - ) - - -def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[t.Tuple[str, t.Any]]: +def _iter_data(data: t.Mapping[str, t.Any]) -> t.Iterator[tuple[str, t.Any]]: """Iterate over a mapping that might have a list of values, yielding all key, value pairs. Almost like iter_multi_items but only allows lists, not tuples, of values so tuples can be used for files. @@ -302,11 +239,13 @@ class EnvironBuilder: Serialized with the function assigned to :attr:`json_dumps`. :param environ_base: an optional dict of environment defaults. :param environ_overrides: an optional dict of environment overrides. - :param charset: the charset used to encode string data. :param auth: An authorization object to use for the ``Authorization`` header value. A ``(username, password)`` tuple is a shortcut for ``Basic`` authorization. + .. versionchanged:: 3.0 + The ``charset`` parameter was removed. + .. versionchanged:: 2.1 ``CONTENT_TYPE`` and ``CONTENT_LENGTH`` are not duplicated as header keys in the environ. 
@@ -350,49 +289,45 @@ class EnvironBuilder: json_dumps = staticmethod(json.dumps) del json - _args: t.Optional[MultiDict] - _query_string: t.Optional[str] - _input_stream: t.Optional[t.IO[bytes]] - _form: t.Optional[MultiDict] - _files: t.Optional[FileMultiDict] + _args: MultiDict | None + _query_string: str | None + _input_stream: t.IO[bytes] | None + _form: MultiDict | None + _files: FileMultiDict | None def __init__( self, path: str = "/", - base_url: t.Optional[str] = None, - query_string: t.Optional[t.Union[t.Mapping[str, str], str]] = None, + base_url: str | None = None, + query_string: t.Mapping[str, str] | str | None = None, method: str = "GET", - input_stream: t.Optional[t.IO[bytes]] = None, - content_type: t.Optional[str] = None, - content_length: t.Optional[int] = None, - errors_stream: t.Optional[t.IO[str]] = None, + input_stream: t.IO[bytes] | None = None, + content_type: str | None = None, + content_length: int | None = None, + errors_stream: t.IO[str] | None = None, multithread: bool = False, multiprocess: bool = False, run_once: bool = False, - headers: t.Optional[t.Union[Headers, t.Iterable[t.Tuple[str, str]]]] = None, - data: t.Optional[ - t.Union[t.IO[bytes], str, bytes, t.Mapping[str, t.Any]] - ] = None, - environ_base: t.Optional[t.Mapping[str, t.Any]] = None, - environ_overrides: t.Optional[t.Mapping[str, t.Any]] = None, - charset: str = "utf-8", - mimetype: t.Optional[str] = None, - json: t.Optional[t.Mapping[str, t.Any]] = None, - auth: t.Optional[t.Union[Authorization, t.Tuple[str, str]]] = None, + headers: Headers | t.Iterable[tuple[str, str]] | None = None, + data: None | (t.IO[bytes] | str | bytes | t.Mapping[str, t.Any]) = None, + environ_base: t.Mapping[str, t.Any] | None = None, + environ_overrides: t.Mapping[str, t.Any] | None = None, + mimetype: str | None = None, + json: t.Mapping[str, t.Any] | None = None, + auth: Authorization | tuple[str, str] | None = None, ) -> None: - path_s = _make_encode_wrapper(path) - if query_string is 
not None and path_s("?") in path: + if query_string is not None and "?" in path: raise ValueError("Query string is defined in the path and as an argument") - request_uri = url_parse(path) - if query_string is None and path_s("?") in path: + request_uri = urlsplit(path) + if query_string is None and "?" in path: query_string = request_uri.query - self.charset = charset + self.path = iri_to_uri(request_uri.path) self.request_uri = path if base_url is not None: - base_url = url_fix(iri_to_uri(base_url, charset), charset) + base_url = iri_to_uri(base_url) self.base_url = base_url # type: ignore - if isinstance(query_string, (bytes, str)): + if isinstance(query_string, str): self.query_string = query_string else: if query_string is None: @@ -441,15 +376,15 @@ class EnvironBuilder: if input_stream is not None: raise TypeError("can't provide input stream and data") if hasattr(data, "read"): - data = data.read() # type: ignore + data = data.read() if isinstance(data, str): - data = data.encode(self.charset) + data = data.encode() if isinstance(data, bytes): self.input_stream = BytesIO(data) if self.content_length is None: self.content_length = len(data) else: - for key, value in _iter_data(data): # type: ignore + for key, value in _iter_data(data): if isinstance(value, (tuple, dict)) or hasattr(value, "read"): self._add_file_from_data(key, value) else: @@ -459,9 +394,7 @@ class EnvironBuilder: self.mimetype = mimetype @classmethod - def from_environ( - cls, environ: "WSGIEnvironment", **kwargs: t.Any - ) -> "EnvironBuilder": + def from_environ(cls, environ: WSGIEnvironment, **kwargs: t.Any) -> EnvironBuilder: """Turn an environ dict back into a builder. Any extra kwargs override the args extracted from the environ. 
@@ -496,9 +429,7 @@ class EnvironBuilder: def _add_file_from_data( self, key: str, - value: t.Union[ - t.IO[bytes], t.Tuple[t.IO[bytes], str], t.Tuple[t.IO[bytes], str, str] - ], + value: (t.IO[bytes] | tuple[t.IO[bytes], str] | tuple[t.IO[bytes], str, str]), ) -> None: """Called in the EnvironBuilder to add files from the data dict.""" if isinstance(value, tuple): @@ -508,7 +439,7 @@ class EnvironBuilder: @staticmethod def _make_base_url(scheme: str, host: str, script_root: str) -> str: - return url_unparse((scheme, host, script_root, "", "")).rstrip("/") + "/" + return urlunsplit((scheme, host, script_root, "", "")).rstrip("/") + "/" @property def base_url(self) -> str: @@ -518,13 +449,13 @@ class EnvironBuilder: return self._make_base_url(self.url_scheme, self.host, self.script_root) @base_url.setter - def base_url(self, value: t.Optional[str]) -> None: + def base_url(self, value: str | None) -> None: if value is None: scheme = "http" netloc = "localhost" script_root = "" else: - scheme, netloc, script_root, qs, anchor = url_parse(value) + scheme, netloc, script_root, qs, anchor = urlsplit(value) if qs or anchor: raise ValueError("base url must not contain a query string or fragment") self.script_root = script_root.rstrip("/") @@ -532,7 +463,7 @@ class EnvironBuilder: self.url_scheme = scheme @property - def content_type(self) -> t.Optional[str]: + def content_type(self) -> str | None: """The content type for the request. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -547,14 +478,14 @@ class EnvironBuilder: return ct @content_type.setter - def content_type(self, value: t.Optional[str]) -> None: + def content_type(self, value: str | None) -> None: if value is None: self.headers.pop("Content-Type", None) else: self.headers["Content-Type"] = value @property - def mimetype(self) -> t.Optional[str]: + def mimetype(self) -> str | None: """The mimetype (content type without charset etc.) .. 
versionadded:: 0.14 @@ -564,7 +495,7 @@ class EnvironBuilder: @mimetype.setter def mimetype(self, value: str) -> None: - self.content_type = get_content_type(value, self.charset) + self.content_type = get_content_type(value, "utf-8") @property def mimetype_params(self) -> t.Mapping[str, str]: @@ -582,7 +513,7 @@ class EnvironBuilder: return CallbackDict(d, on_update) @property - def content_length(self) -> t.Optional[int]: + def content_length(self) -> int | None: """The content length as integer. Reflected from and to the :attr:`headers`. Do not set if you set :attr:`files` or :attr:`form` for auto detection. @@ -590,13 +521,13 @@ class EnvironBuilder: return self.headers.get("Content-Length", type=int) @content_length.setter - def content_length(self, value: t.Optional[int]) -> None: + def content_length(self, value: int | None) -> None: if value is None: self.headers.pop("Content-Length", None) else: self.headers["Content-Length"] = str(value) - def _get_form(self, name: str, storage: t.Type[_TAnyMultiDict]) -> _TAnyMultiDict: + def _get_form(self, name: str, storage: type[_TAnyMultiDict]) -> _TAnyMultiDict: """Common behavior for getting the :attr:`form` and :attr:`files` properties. @@ -645,7 +576,7 @@ class EnvironBuilder: self._set_form("_files", value) @property - def input_stream(self) -> t.Optional[t.IO[bytes]]: + def input_stream(self) -> t.IO[bytes] | None: """An optional input stream. This is mutually exclusive with setting :attr:`form` and :attr:`files`, setting it will clear those. 
Do not provide this if the method is not ``POST`` or @@ -654,7 +585,7 @@ class EnvironBuilder: return self._input_stream @input_stream.setter - def input_stream(self, value: t.Optional[t.IO[bytes]]) -> None: + def input_stream(self, value: t.IO[bytes] | None) -> None: self._input_stream = value self._form = None self._files = None @@ -666,12 +597,12 @@ class EnvironBuilder: """ if self._query_string is None: if self._args is not None: - return url_encode(self._args, charset=self.charset) + return _urlencode(self._args) return "" return self._query_string @query_string.setter - def query_string(self, value: t.Optional[str]) -> None: + def query_string(self, value: str | None) -> None: self._query_string = value self._args = None @@ -685,7 +616,7 @@ class EnvironBuilder: return self._args @args.setter - def args(self, value: t.Optional[MultiDict]) -> None: + def args(self, value: MultiDict | None) -> None: self._query_string = None self._args = value @@ -733,7 +664,7 @@ class EnvironBuilder: pass self.closed = True - def get_environ(self) -> "WSGIEnvironment": + def get_environ(self) -> WSGIEnvironment: """Return the built environ. .. 
versionchanged:: 0.15 @@ -755,30 +686,30 @@ class EnvironBuilder: content_length = end_pos - start_pos elif mimetype == "multipart/form-data": input_stream, content_length, boundary = stream_encode_multipart( - CombinedMultiDict([self.form, self.files]), charset=self.charset + CombinedMultiDict([self.form, self.files]) ) content_type = f'{mimetype}; boundary="{boundary}"' elif mimetype == "application/x-www-form-urlencoded": - form_encoded = url_encode(self.form, charset=self.charset).encode("ascii") + form_encoded = _urlencode(self.form).encode("ascii") content_length = len(form_encoded) input_stream = BytesIO(form_encoded) else: input_stream = BytesIO() - result: "WSGIEnvironment" = {} + result: WSGIEnvironment = {} if self.environ_base: result.update(self.environ_base) def _path_encode(x: str) -> str: - return _wsgi_encoding_dance(url_unquote(x, self.charset), self.charset) + return _wsgi_encoding_dance(unquote(x)) - raw_uri = _wsgi_encoding_dance(self.request_uri, self.charset) + raw_uri = _wsgi_encoding_dance(self.request_uri) result.update( { "REQUEST_METHOD": self.method, "SCRIPT_NAME": _path_encode(self.script_root), "PATH_INFO": _path_encode(self.path), - "QUERY_STRING": _wsgi_encoding_dance(self.query_string, self.charset), + "QUERY_STRING": _wsgi_encoding_dance(self.query_string), # Non-standard, added by mod_wsgi, uWSGI "REQUEST_URI": raw_uri, # Non-standard, added by gunicorn @@ -821,7 +752,7 @@ class EnvironBuilder: return result - def get_request(self, cls: t.Optional[t.Type[Request]] = None) -> Request: + def get_request(self, cls: type[Request] | None = None) -> Request: """Returns a request with the data. If the request class is not specified :attr:`request_class` is used. @@ -840,24 +771,28 @@ class ClientRedirectError(Exception): class Client: - """This class allows you to send requests to a wrapped application. + """Simulate sending requests to a WSGI application without running a WSGI or HTTP + server. 
- The use_cookies parameter indicates whether cookies should be stored and - sent for subsequent requests. This is True by default, but passing False - will disable this behaviour. + :param application: The WSGI application to make requests to. + :param response_wrapper: A :class:`.Response` class to wrap response data with. + Defaults to :class:`.TestResponse`. If it's not a subclass of ``TestResponse``, + one will be created. + :param use_cookies: Persist cookies from ``Set-Cookie`` response headers to the + ``Cookie`` header in subsequent requests. Domain and path matching is supported, + but other cookie parameters are ignored. + :param allow_subdomain_redirects: Allow requests to follow redirects to subdomains. + Enable this if the application handles subdomains and redirects between them. - If you want to request some subdomain of your application you may set - `allow_subdomain_redirects` to `True` as if not no external redirects - are allowed. + .. versionchanged:: 2.3 + Simplify cookie implementation, support domain and path matching. .. versionchanged:: 2.1 - Removed deprecated behavior of treating the response as a - tuple. All data is available as properties on the returned - response object. + All data is available as properties on the returned response object. The + response cannot be returned as a tuple. .. versionchanged:: 2.0 - ``response_wrapper`` is always a subclass of - :class:``TestResponse``. + ``response_wrapper`` is always a subclass of :class:``TestResponse``. .. versionchanged:: 0.5 Added the ``use_cookies`` parameter. 
@@ -865,8 +800,8 @@ class Client: def __init__( self, - application: "WSGIApplication", - response_wrapper: t.Optional[t.Type["Response"]] = None, + application: WSGIApplication, + response_wrapper: type[Response] | None = None, use_cookies: bool = True, allow_subdomain_redirects: bool = False, ) -> None: @@ -884,96 +819,186 @@ class Client: self.response_wrapper = t.cast(t.Type["TestResponse"], response_wrapper) if use_cookies: - self.cookie_jar: t.Optional[_TestCookieJar] = _TestCookieJar() + self._cookies: dict[tuple[str, str, str], Cookie] | None = {} else: - self.cookie_jar = None + self._cookies = None self.allow_subdomain_redirects = allow_subdomain_redirects + def get_cookie( + self, key: str, domain: str = "localhost", path: str = "/" + ) -> Cookie | None: + """Return a :class:`.Cookie` if it exists. Cookies are uniquely identified by + ``(domain, path, key)``. + + :param key: The decoded form of the key for the cookie. + :param domain: The domain the cookie was set for. + :param path: The path the cookie was set for. + + .. versionadded:: 2.3 + """ + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." + ) + + return self._cookies.get((domain, path, key)) + def set_cookie( self, - server_name: str, key: str, value: str = "", - max_age: t.Optional[t.Union[timedelta, int]] = None, - expires: t.Optional[t.Union[str, datetime, int, float]] = None, + *, + domain: str = "localhost", + origin_only: bool = True, path: str = "/", - domain: t.Optional[str] = None, - secure: bool = False, - httponly: bool = False, - samesite: t.Optional[str] = None, - charset: str = "utf-8", + **kwargs: t.Any, ) -> None: - """Sets a cookie in the client's cookie jar. The server name - is required and has to match the one that is also passed to - the open call. + """Set a cookie to be sent in subsequent requests. + + This is a convenience to skip making a test request to a route that would set + the cookie. 
To test the cookie, make a test request to a route that uses the + cookie value. + + The client uses ``domain``, ``origin_only``, and ``path`` to determine which + cookies to send with a request. It does not use other cookie parameters that + browsers use, since they're not applicable in tests. + + :param key: The key part of the cookie. + :param value: The value part of the cookie. + :param domain: Send this cookie with requests that match this domain. If + ``origin_only`` is true, it must be an exact match, otherwise it may be a + suffix match. + :param origin_only: Whether the domain must be an exact match to the request. + :param path: Send this cookie with requests that match this path either exactly + or as a prefix. + :param kwargs: Passed to :func:`.dump_cookie`. + + .. versionchanged:: 3.0 + The parameter ``server_name`` is removed. The first parameter is + ``key``. Use the ``domain`` and ``origin_only`` parameters instead. + + .. versionchanged:: 2.3 + The ``origin_only`` parameter was added. + + .. versionchanged:: 2.3 + The ``domain`` parameter defaults to ``localhost``. """ - assert self.cookie_jar is not None, "cookies disabled" - header = dump_cookie( - key, - value, - max_age, - expires, - path, - domain, - secure, - httponly, - charset, - samesite=samesite, + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." 
+ ) + + cookie = Cookie._from_response_header( + domain, "/", dump_cookie(key, value, domain=domain, path=path, **kwargs) ) - environ = create_environ(path, base_url=f"http://{server_name}") - headers = [("Set-Cookie", header)] - self.cookie_jar.extract_wsgi(environ, headers) + cookie.origin_only = origin_only + + if cookie._should_delete: + self._cookies.pop(cookie._storage_key, None) + else: + self._cookies[cookie._storage_key] = cookie def delete_cookie( self, - server_name: str, key: str, + *, + domain: str = "localhost", path: str = "/", - domain: t.Optional[str] = None, - secure: bool = False, - httponly: bool = False, - samesite: t.Optional[str] = None, ) -> None: - """Deletes a cookie in the test client.""" - self.set_cookie( - server_name, - key, - expires=0, - max_age=0, - path=path, - domain=domain, - secure=secure, - httponly=httponly, - samesite=samesite, + """Delete a cookie if it exists. Cookies are uniquely identified by + ``(domain, path, key)``. + + :param key: The decoded form of the key for the cookie. + :param domain: The domain the cookie was set for. + :param path: The path the cookie was set for. + + .. versionchanged:: 3.0 + The ``server_name`` parameter is removed. The first parameter is + ``key``. Use the ``domain`` parameter instead. + + .. versionchanged:: 3.0 + The ``secure``, ``httponly`` and ``samesite`` parameters are removed. + + .. versionchanged:: 2.3 + The ``domain`` parameter defaults to ``localhost``. + """ + if self._cookies is None: + raise TypeError( + "Cookies are disabled. Create a client with 'use_cookies=True'." + ) + + self._cookies.pop((domain, path, key), None) + + def _add_cookies_to_wsgi(self, environ: WSGIEnvironment) -> None: + """If cookies are enabled, set the ``Cookie`` header in the environ to the + cookies that are applicable to the request host and path. + + :meta private: + + .. 
versionadded:: 2.3 + """ + if self._cookies is None: + return + + url = urlsplit(get_current_url(environ)) + server_name = url.hostname or "localhost" + value = "; ".join( + c._to_request_header() + for c in self._cookies.values() + if c._matches_request(server_name, url.path) ) + if value: + environ["HTTP_COOKIE"] = value + else: + environ.pop("HTTP_COOKIE", None) + + def _update_cookies_from_response( + self, server_name: str, path: str, headers: list[str] + ) -> None: + """If cookies are enabled, update the stored cookies from any ``Set-Cookie`` + headers in the response. + + :meta private: + + .. versionadded:: 2.3 + """ + if self._cookies is None: + return + + for header in headers: + cookie = Cookie._from_response_header(server_name, path, header) + + if cookie._should_delete: + self._cookies.pop(cookie._storage_key, None) + else: + self._cookies[cookie._storage_key] = cookie + def run_wsgi_app( - self, environ: "WSGIEnvironment", buffered: bool = False - ) -> t.Tuple[t.Iterable[bytes], str, Headers]: + self, environ: WSGIEnvironment, buffered: bool = False + ) -> tuple[t.Iterable[bytes], str, Headers]: """Runs the wrapped WSGI app with the given environment. :meta private: """ - if self.cookie_jar is not None: - self.cookie_jar.inject_wsgi(environ) - + self._add_cookies_to_wsgi(environ) rv = run_wsgi_app(self.application, environ, buffered=buffered) - - if self.cookie_jar is not None: - self.cookie_jar.extract_wsgi(environ, rv[2]) - + url = urlsplit(get_current_url(environ)) + self._update_cookies_from_response( + url.hostname or "localhost", url.path, rv[2].getlist("Set-Cookie") + ) return rv def resolve_redirect( - self, response: "TestResponse", buffered: bool = False - ) -> "TestResponse": + self, response: TestResponse, buffered: bool = False + ) -> TestResponse: """Perform a new request to the location given by the redirect response to the previous request. 
:meta private: """ - scheme, netloc, path, qs, anchor = url_parse(response.location) + scheme, netloc, path, qs, anchor = urlsplit(response.location) builder = EnvironBuilder.from_environ( response.request.environ, path=path, query_string=qs ) @@ -1034,7 +1059,7 @@ class Client: buffered: bool = False, follow_redirects: bool = False, **kwargs: t.Any, - ) -> "TestResponse": + ) -> TestResponse: """Generate an environ dict from the given arguments, make a request to the application using it, and return the response. @@ -1052,11 +1077,6 @@ class Client: .. versionchanged:: 2.1 Removed the ``as_tuple`` parameter. - .. versionchanged:: 2.0 - ``as_tuple`` is deprecated and will be removed in Werkzeug - 2.1. Use :attr:`TestResponse.request` and - ``request.environ`` instead. - .. versionchanged:: 2.0 The request input stream is closed when calling ``response.close()``. Input streams for redirects are @@ -1071,7 +1091,7 @@ class Client: .. versionchanged:: 0.5 Added the ``follow_redirects`` parameter. 
""" - request: t.Optional["Request"] = None + request: Request | None = None if not kwargs and len(args) == 1: arg = args[0] @@ -1095,7 +1115,7 @@ class Client: response = self.response_wrapper(*response, request=request) redirects = set() - history: t.List["TestResponse"] = [] + history: list[TestResponse] = [] if not follow_redirects: return response @@ -1134,42 +1154,42 @@ class Client: response.call_on_close(request.input_stream.close) return response - def get(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def get(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``GET``.""" kw["method"] = "GET" return self.open(*args, **kw) - def post(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def post(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``POST``.""" kw["method"] = "POST" return self.open(*args, **kw) - def put(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def put(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``PUT``.""" kw["method"] = "PUT" return self.open(*args, **kw) - def delete(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def delete(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``DELETE``.""" kw["method"] = "DELETE" return self.open(*args, **kw) - def patch(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def patch(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``PATCH``.""" kw["method"] = "PATCH" return self.open(*args, **kw) - def options(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def options(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``OPTIONS``.""" kw["method"] = "OPTIONS" return self.open(*args, **kw) - def head(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def head(self, *args: t.Any, **kw: t.Any) -> TestResponse: 
"""Call :meth:`open` with ``method`` set to ``HEAD``.""" kw["method"] = "HEAD" return self.open(*args, **kw) - def trace(self, *args: t.Any, **kw: t.Any) -> "TestResponse": + def trace(self, *args: t.Any, **kw: t.Any) -> TestResponse: """Call :meth:`open` with ``method`` set to ``TRACE``.""" kw["method"] = "TRACE" return self.open(*args, **kw) @@ -1178,7 +1198,7 @@ class Client: return f"<{type(self).__name__} {self.application!r}>" -def create_environ(*args: t.Any, **kwargs: t.Any) -> "WSGIEnvironment": +def create_environ(*args: t.Any, **kwargs: t.Any) -> WSGIEnvironment: """Create a new WSGI environ dict based on the values passed. The first parameter should be the path of the request which defaults to '/'. The second one can either be an absolute path (in that case the host is @@ -1202,8 +1222,8 @@ def create_environ(*args: t.Any, **kwargs: t.Any) -> "WSGIEnvironment": def run_wsgi_app( - app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False -) -> t.Tuple[t.Iterable[bytes], str, Headers]: + app: WSGIApplication, environ: WSGIEnvironment, buffered: bool = False +) -> tuple[t.Iterable[bytes], str, Headers]: """Return a tuple in the form (app_iter, status, headers) of the application output. This works best if you pass it an application that returns an iterator all the time. @@ -1224,8 +1244,8 @@ def run_wsgi_app( # example) don't affect subsequent requests (such as redirects). environ = _get_environ(environ).copy() status: str - response: t.Optional[t.Tuple[str, t.List[t.Tuple[str, str]]]] = None - buffer: t.List[bytes] = [] + response: tuple[str, list[tuple[str, str]]] | None = None + buffer: list[bytes] = [] def start_response(status, headers, exc_info=None): # type: ignore nonlocal response @@ -1290,8 +1310,7 @@ class TestResponse(Response): assumed if missing. .. versionchanged:: 2.1 - Removed deprecated behavior for treating the response instance - as a tuple. + Response instances cannot be treated as tuples. .. 
versionadded:: 2.0
         Test client methods always return instances of this class.
@@ -1305,7 +1324,7 @@ class TestResponse(Response):
     resulted in this response.
     """
 
-    history: t.Tuple["TestResponse", ...]
+    history: tuple[TestResponse, ...]
     """A list of intermediate responses. Populated when the test request
     is made with ``follow_redirects`` enabled.
     """
@@ -1319,7 +1338,7 @@
         status: str,
         headers: Headers,
         request: Request,
-        history: t.Tuple["TestResponse"] = (),  # type: ignore
+        history: tuple[TestResponse] = (),  # type: ignore
         **kwargs: t.Any,
     ) -> None:
         super().__init__(response, status, headers, **kwargs)
@@ -1335,3 +1354,109 @@
         .. versionadded:: 2.1
         """
         return self.get_data(as_text=True)
+
+
+@dataclasses.dataclass
+class Cookie:
+    """A cookie key, value, and parameters.
+
+    The class itself is not a public API. Its attributes are documented for inspection
+    with :meth:`.Client.get_cookie` only.
+
+    .. versionadded:: 2.3
+    """
+
+    key: str
+    """The cookie key, encoded as a client would see it."""
+
+    value: str
+    """The cookie value, encoded as a client would see it."""
+
+    decoded_key: str
+    """The cookie key, decoded as the application would set and see it."""
+
+    decoded_value: str
+    """The cookie value, decoded as the application would set and see it."""
+
+    expires: datetime | None
+    """The time at which the cookie is no longer valid."""
+
+    max_age: int | None
+    """The number of seconds from when the cookie was set at which it is
+    no longer valid.
+    """
+
+    domain: str
+    """The domain that the cookie was set for, or the request domain if not set."""
+
+    origin_only: bool
+    """Whether the cookie will be sent for exact domain matches only. This is ``True``
+    if the ``Domain`` parameter was not present.
+ """ + + path: str + """The path that the cookie was set for.""" + + secure: bool | None + """The ``Secure`` parameter.""" + + http_only: bool | None + """The ``HttpOnly`` parameter.""" + + same_site: str | None + """The ``SameSite`` parameter.""" + + def _matches_request(self, server_name: str, path: str) -> bool: + return ( + server_name == self.domain + or ( + not self.origin_only + and server_name.endswith(self.domain) + and server_name[: -len(self.domain)].endswith(".") + ) + ) and ( + path == self.path + or ( + path.startswith(self.path) + and path[len(self.path) - self.path.endswith("/") :].startswith("/") + ) + ) + + def _to_request_header(self) -> str: + return f"{self.key}={self.value}" + + @classmethod + def _from_response_header(cls, server_name: str, path: str, header: str) -> te.Self: + header, _, parameters_str = header.partition(";") + key, _, value = header.partition("=") + decoded_key, decoded_value = next(parse_cookie(header).items()) + params = {} + + for item in parameters_str.split(";"): + k, sep, v = item.partition("=") + params[k.strip().lower()] = v.strip() if sep else None + + return cls( + key=key.strip(), + value=value.strip(), + decoded_key=decoded_key, + decoded_value=decoded_value, + expires=parse_date(params.get("expires")), + max_age=int(params["max-age"] or 0) if "max-age" in params else None, + domain=params.get("domain") or server_name, + origin_only="domain" not in params, + path=params.get("path") or path.rpartition("/")[0] or "/", + secure="secure" in params, + http_only="httponly" in params, + same_site=params.get("samesite"), + ) + + @property + def _storage_key(self) -> tuple[str, str, str]: + return self.domain, self.path, self.decoded_key + + @property + def _should_delete(self) -> bool: + return self.max_age == 0 or ( + self.expires is not None and self.expires.timestamp() == 0 + ) diff --git a/src/werkzeug/testapp.py b/src/werkzeug/testapp.py index 0d7ffbb..57f1f6f 100644 --- a/src/werkzeug/testapp.py +++ 
b/src/werkzeug/testapp.py @@ -1,7 +1,8 @@ """A small application that can be used to test a WSGI server and check it for WSGI compliance. """ -import base64 +from __future__ import annotations + import os import sys import typing as t @@ -13,53 +14,6 @@ from . import __version__ as _werkzeug_version from .wrappers.request import Request from .wrappers.response import Response -if t.TYPE_CHECKING: - from _typeshed.wsgi import StartResponse - from _typeshed.wsgi import WSGIEnvironment - - -logo = Response( - base64.b64decode( - """ -R0lGODlhoACgAOMIAAEDACwpAEpCAGdgAJaKAM28AOnVAP3rAP///////// -//////////////////////yH5BAEKAAgALAAAAACgAKAAAAT+EMlJq704680R+F0ojmRpnuj0rWnrv -nB8rbRs33gu0bzu/0AObxgsGn3D5HHJbCUFyqZ0ukkSDlAidctNFg7gbI9LZlrBaHGtzAae0eloe25 -7w9EDOX2fst/xenyCIn5/gFqDiVVDV4aGeYiKkhSFjnCQY5OTlZaXgZp8nJ2ekaB0SQOjqphrpnOiq -ncEn65UsLGytLVmQ6m4sQazpbtLqL/HwpnER8bHyLrLOc3Oz8PRONPU1crXN9na263dMt/g4SzjMeX -m5yDpLqgG7OzJ4u8lT/P69ej3JPn69kHzN2OIAHkB9RUYSFCFQYQJFTIkCDBiwoXWGnowaLEjRm7+G -p9A7Hhx4rUkAUaSLJlxHMqVMD/aSycSZkyTplCqtGnRAM5NQ1Ly5OmzZc6gO4d6DGAUKA+hSocWYAo -SlM6oUWX2O/o0KdaVU5vuSQLAa0ADwQgMEMB2AIECZhVSnTno6spgbtXmHcBUrQACcc2FrTrWS8wAf -78cMFBgwIBgbN+qvTt3ayikRBk7BoyGAGABAdYyfdzRQGV3l4coxrqQ84GpUBmrdR3xNIDUPAKDBSA -ADIGDhhqTZIWaDcrVX8EsbNzbkvCOxG8bN5w8ly9H8jyTJHC6DFndQydbguh2e/ctZJFXRxMAqqPVA -tQH5E64SPr1f0zz7sQYjAHg0In+JQ11+N2B0XXBeeYZgBZFx4tqBToiTCPv0YBgQv8JqA6BEf6RhXx -w1ENhRBnWV8ctEX4Ul2zc3aVGcQNC2KElyTDYyYUWvShdjDyMOGMuFjqnII45aogPhz/CodUHFwaDx -lTgsaOjNyhGWJQd+lFoAGk8ObghI0kawg+EV5blH3dr+digkYuAGSaQZFHFz2P/cTaLmhF52QeSb45 -Jwxd+uSVGHlqOZpOeJpCFZ5J+rkAkFjQ0N1tah7JJSZUFNsrkeJUJMIBi8jyaEKIhKPomnC91Uo+NB -yyaJ5umnnpInIFh4t6ZSpGaAVmizqjpByDegYl8tPE0phCYrhcMWSv+uAqHfgH88ak5UXZmlKLVJhd -dj78s1Fxnzo6yUCrV6rrDOkluG+QzCAUTbCwf9SrmMLzK6p+OPHx7DF+bsfMRq7Ec61Av9i6GLw23r -idnZ+/OO0a99pbIrJkproCQMA17OPG6suq3cca5ruDfXCCDoS7BEdvmJn5otdqscn+uogRHHXs8cbh -EIfYaDY1AkrC0cqwcZpnM6ludx72x0p7Fo/hZAcpJDjax0UdHavMKAbiKltMWCF3xxh9k25N/Viud8 
-ba78iCvUkt+V6BpwMlErmcgc502x+u1nSxJSJP9Mi52awD1V4yB/QHONsnU3L+A/zR4VL/indx/y64 -gqcj+qgTeweM86f0Qy1QVbvmWH1D9h+alqg254QD8HJXHvjQaGOqEqC22M54PcftZVKVSQG9jhkv7C -JyTyDoAJfPdu8v7DRZAxsP/ky9MJ3OL36DJfCFPASC3/aXlfLOOON9vGZZHydGf8LnxYJuuVIbl83y -Az5n/RPz07E+9+zw2A2ahz4HxHo9Kt79HTMx1Q7ma7zAzHgHqYH0SoZWyTuOLMiHwSfZDAQTn0ajk9 -YQqodnUYjByQZhZak9Wu4gYQsMyEpIOAOQKze8CmEF45KuAHTvIDOfHJNipwoHMuGHBnJElUoDmAyX -c2Qm/R8Ah/iILCCJOEokGowdhDYc/yoL+vpRGwyVSCWFYZNljkhEirGXsalWcAgOdeAdoXcktF2udb -qbUhjWyMQxYO01o6KYKOr6iK3fE4MaS+DsvBsGOBaMb0Y6IxADaJhFICaOLmiWTlDAnY1KzDG4ambL -cWBA8mUzjJsN2KjSaSXGqMCVXYpYkj33mcIApyhQf6YqgeNAmNvuC0t4CsDbSshZJkCS1eNisKqlyG -cF8G2JeiDX6tO6Mv0SmjCa3MFb0bJaGPMU0X7c8XcpvMaOQmCajwSeY9G0WqbBmKv34DsMIEztU6Y2 -KiDlFdt6jnCSqx7Dmt6XnqSKaFFHNO5+FmODxMCWBEaco77lNDGXBM0ECYB/+s7nKFdwSF5hgXumQe -EZ7amRg39RHy3zIjyRCykQh8Zo2iviRKyTDn/zx6EefptJj2Cw+Ep2FSc01U5ry4KLPYsTyWnVGnvb -UpyGlhjBUljyjHhWpf8OFaXwhp9O4T1gU9UeyPPa8A2l0p1kNqPXEVRm1AOs1oAGZU596t6SOR2mcB -Oco1srWtkaVrMUzIErrKri85keKqRQYX9VX0/eAUK1hrSu6HMEX3Qh2sCh0q0D2CtnUqS4hj62sE/z -aDs2Sg7MBS6xnQeooc2R2tC9YrKpEi9pLXfYXp20tDCpSP8rKlrD4axprb9u1Df5hSbz9QU0cRpfgn -kiIzwKucd0wsEHlLpe5yHXuc6FrNelOl7pY2+11kTWx7VpRu97dXA3DO1vbkhcb4zyvERYajQgAADs -=""" - ), - mimetype="image/png", -) - - TEMPLATE = """\ @@ -70,7 +24,6 @@ TEMPLATE = """\ body { font-family: 'Lucida Grande', 'Lucida Sans Unicode', 'Geneva', 'Verdana', sans-serif; background-color: white; color: #000; font-size: 15px; text-align: center; } - #logo { float: right; padding: 0 0 10px 10px; } div.box { text-align: left; width: 45em; margin: auto; padding: 50px 0; background-color: white; } h1, h2 { font-family: 'Ubuntu', 'Lucida Grande', 'Lucida Sans Unicode', @@ -92,7 +45,6 @@ TEMPLATE = """\ li.exp { background: white; }
-

WSGI Information

This page displays all available information about the WSGI server and @@ -139,7 +91,7 @@ TEMPLATE = """\ """ -def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: +def iter_sys_path() -> t.Iterator[tuple[str, bool, bool]]: if os.name == "posix": def strip(x: str) -> str: @@ -159,7 +111,21 @@ def iter_sys_path() -> t.Iterator[t.Tuple[str, bool, bool]]: yield strip(os.path.normpath(path)), not os.path.isdir(path), path != item -def render_testapp(req: Request) -> bytes: +@Request.application +def test_app(req: Request) -> Response: + """Simple test application that dumps the environment. You can use + it to check if Werkzeug is working properly: + + .. sourcecode:: pycon + + >>> from werkzeug.serving import run_simple + >>> from werkzeug.testapp import test_app + >>> run_simple('localhost', 3000, test_app) + * Running on http://localhost:3000/ + + The application displays important information from the WSGI environment, + the Python interpreter and the installed libraries. + """ try: import pkg_resources except ImportError: @@ -167,7 +133,7 @@ def render_testapp(req: Request) -> bytes: else: eggs = sorted( pkg_resources.working_set, - key=lambda x: x.project_name.lower(), # type: ignore + key=lambda x: x.project_name.lower(), ) python_eggs = [] for egg in eggs: @@ -195,44 +161,18 @@ def render_testapp(req: Request) -> bytes: class_ = f' class="{" ".join(class_)}"' if class_ else "" sys_path.append(f"{escape(item)}") - return ( - TEMPLATE - % { - "python_version": "
".join(escape(sys.version).splitlines()), - "platform": escape(sys.platform), - "os": escape(os.name), - "api_version": sys.api_version, - "byteorder": sys.byteorder, - "werkzeug_version": _werkzeug_version, - "python_eggs": "\n".join(python_eggs), - "wsgi_env": "\n".join(wsgi_env), - "sys_path": "\n".join(sys_path), - } - ).encode("utf-8") - - -def test_app( - environ: "WSGIEnvironment", start_response: "StartResponse" -) -> t.Iterable[bytes]: - """Simple test application that dumps the environment. You can use - it to check if Werkzeug is working properly: - - .. sourcecode:: pycon - - >>> from werkzeug.serving import run_simple - >>> from werkzeug.testapp import test_app - >>> run_simple('localhost', 3000, test_app) - * Running on http://localhost:3000/ - - The application displays important information from the WSGI environment, - the Python interpreter and the installed libraries. - """ - req = Request(environ, populate_request=False) - if req.args.get("resource") == "logo": - response = logo - else: - response = Response(render_testapp(req), mimetype="text/html") - return response(environ, start_response) + context = { + "python_version": "
".join(escape(sys.version).splitlines()), + "platform": escape(sys.platform), + "os": escape(os.name), + "api_version": sys.api_version, + "byteorder": sys.byteorder, + "werkzeug_version": _werkzeug_version, + "python_eggs": "\n".join(python_eggs), + "wsgi_env": "\n".join(wsgi_env), + "sys_path": "\n".join(sys_path), + } + return Response(TEMPLATE % context, mimetype="text/html") if __name__ == "__main__": diff --git a/src/werkzeug/urls.py b/src/werkzeug/urls.py index 67c08b0..4d61e60 100644 --- a/src/werkzeug/urls.py +++ b/src/werkzeug/urls.py @@ -1,722 +1,63 @@ -"""Functions for working with URLs. +from __future__ import annotations -Contains implementations of functions from :mod:`urllib.parse` that -handle bytes and strings. -""" import codecs -import os import re import typing as t +from urllib.parse import quote +from urllib.parse import unquote +from urllib.parse import urlencode +from urllib.parse import urlsplit +from urllib.parse import urlunsplit -from ._internal import _check_str_tuple -from ._internal import _decode_idna -from ._internal import _encode_idna -from ._internal import _make_encode_wrapper -from ._internal import _to_str +from .datastructures import iter_multi_items -if t.TYPE_CHECKING: - from . import datastructures as ds -# A regular expression for what a valid schema looks like -_scheme_re = re.compile(r"^[a-zA-Z0-9+-.]+$") - -# Characters that are safe in any part of an URL. 
-_always_safe = frozenset( - bytearray( - b"abcdefghijklmnopqrstuvwxyz" - b"ABCDEFGHIJKLMNOPQRSTUVWXYZ" - b"0123456789" - b"-._~" - b"$!'()*+,;" # RFC3986 sub-delims set, not including query string delimiters &= - ) -) - -_hexdigits = "0123456789ABCDEFabcdef" -_hextobyte = { - f"{a}{b}".encode("ascii"): int(f"{a}{b}", 16) - for a in _hexdigits - for b in _hexdigits -} -_bytetohex = [f"%{char:02X}".encode("ascii") for char in range(256)] - - -class _URLTuple(t.NamedTuple): - scheme: str - netloc: str - path: str - query: str - fragment: str - - -class BaseURL(_URLTuple): - """Superclass of :py:class:`URL` and :py:class:`BytesURL`.""" - - __slots__ = () - _at: str - _colon: str - _lbracket: str - _rbracket: str - - def __str__(self) -> str: - return self.to_url() - - def replace(self, **kwargs: t.Any) -> "BaseURL": - """Return an URL with the same values, except for those parameters - given new values by whichever keyword arguments are specified.""" - return self._replace(**kwargs) - - @property - def host(self) -> t.Optional[str]: - """The host part of the URL if available, otherwise `None`. The - host is either the hostname or the IP address mentioned in the - URL. It will not contain the port. - """ - return self._split_host()[0] - - @property - def ascii_host(self) -> t.Optional[str]: - """Works exactly like :attr:`host` but will return a result that - is restricted to ASCII. If it finds a netloc that is not ASCII - it will attempt to idna decode it. This is useful for socket - operations when the URL might include internationalized characters. - """ - rv = self.host - if rv is not None and isinstance(rv, str): - try: - rv = _encode_idna(rv) # type: ignore - except UnicodeError: - rv = rv.encode("ascii", "ignore") # type: ignore - return _to_str(rv, "ascii", "ignore") - - @property - def port(self) -> t.Optional[int]: - """The port in the URL as an integer if it was present, `None` - otherwise. This does not fill in default ports. 
- """ - try: - rv = int(_to_str(self._split_host()[1])) - if 0 <= rv <= 65535: - return rv - except (ValueError, TypeError): - pass - return None - - @property - def auth(self) -> t.Optional[str]: - """The authentication part in the URL if available, `None` - otherwise. - """ - return self._split_netloc()[0] - - @property - def username(self) -> t.Optional[str]: - """The username if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[0] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_username(self) -> t.Optional[str]: - """The username if it was part of the URL, `None` otherwise. - Unlike :attr:`username` this one is not being decoded. - """ - return self._split_auth()[0] - - @property - def password(self) -> t.Optional[str]: - """The password if it was part of the URL, `None` otherwise. - This undergoes URL decoding and will always be a string. - """ - rv = self._split_auth()[1] - if rv is not None: - return _url_unquote_legacy(rv) - return None - - @property - def raw_password(self) -> t.Optional[str]: - """The password if it was part of the URL, `None` otherwise. - Unlike :attr:`password` this one is not being decoded. - """ - return self._split_auth()[1] - - def decode_query(self, *args: t.Any, **kwargs: t.Any) -> "ds.MultiDict[str, str]": - """Decodes the query part of the URL. Ths is a shortcut for - calling :func:`url_decode` on the query argument. The arguments and - keyword arguments are forwarded to :func:`url_decode` unchanged. - """ - return url_decode(self.query, *args, **kwargs) - - def join(self, *args: t.Any, **kwargs: t.Any) -> "BaseURL": - """Joins this URL with another one. This is just a convenience - function for calling into :meth:`url_join` and then parsing the - return value again. 
- """ - return url_parse(url_join(self, *args, **kwargs)) - - def to_url(self) -> str: - """Returns a URL string or bytes depending on the type of the - information stored. This is just a convenience function - for calling :meth:`url_unparse` for this URL. - """ - return url_unparse(self) - - def encode_netloc(self) -> str: - """Encodes the netloc part to an ASCII safe URL as bytes.""" - rv = self.ascii_host or "" - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - url_quote(self.raw_username or "", "utf-8", "strict", "/:%"), - url_quote(self.raw_password or "", "utf-8", "strict", "/:%"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def decode_netloc(self) -> str: - """Decodes the netloc part into a string.""" - rv = _decode_idna(self.host or "") - - if ":" in rv: - rv = f"[{rv}]" - port = self.port - if port is not None: - rv = f"{rv}:{port}" - auth = ":".join( - filter( - None, - [ - _url_unquote_legacy(self.raw_username or "", "/:%@"), - _url_unquote_legacy(self.raw_password or "", "/:%@"), - ], - ) - ) - if auth: - rv = f"{auth}@{rv}" - return rv - - def to_uri_tuple(self) -> "BaseURL": - """Returns a :class:`BytesURL` tuple that holds a URI. This will - encode all the information in the URL properly to ASCII using the - rules a web browser would follow. - - It's usually more interesting to directly call :meth:`iri_to_uri` which - will return a string. - """ - return url_parse(iri_to_uri(self)) - - def to_iri_tuple(self) -> "BaseURL": - """Returns a :class:`URL` tuple that holds a IRI. This will try - to decode as much information as possible in the URL without - losing information similar to how a web browser does it for the - URL bar. - - It's usually more interesting to directly call :meth:`uri_to_iri` which - will return a string. 
- """ - return url_parse(uri_to_iri(self)) - - def get_file_location( - self, pathformat: t.Optional[str] = None - ) -> t.Tuple[t.Optional[str], t.Optional[str]]: - """Returns a tuple with the location of the file in the form - ``(server, location)``. If the netloc is empty in the URL or - points to localhost, it's represented as ``None``. - - The `pathformat` by default is autodetection but needs to be set - when working with URLs of a specific system. The supported values - are ``'windows'`` when working with Windows or DOS paths and - ``'posix'`` when working with posix paths. - - If the URL does not point to a local file, the server and location - are both represented as ``None``. - - :param pathformat: The expected format of the path component. - Currently ``'windows'`` and ``'posix'`` are - supported. Defaults to ``None`` which is - autodetect. - """ - if self.scheme != "file": - return None, None - - path = url_unquote(self.path) - host = self.netloc or None - - if pathformat is None: - if os.name == "nt": - pathformat = "windows" - else: - pathformat = "posix" - - if pathformat == "windows": - if path[:1] == "/" and path[1:2].isalpha() and path[2:3] in "|:": - path = f"{path[1:2]}:{path[3:]}" - windows_share = path[:3] in ("\\" * 3, "/" * 3) - import ntpath - - path = ntpath.normpath(path) - # Windows shared drives are represented as ``\\host\\directory``. - # That results in a URL like ``file://///host/directory``, and a - # path like ``///host/directory``. We need to special-case this - # because the path contains the hostname. 
- if windows_share and host is None: - parts = path.lstrip("\\").split("\\", 1) - if len(parts) == 2: - host, path = parts - else: - host = parts[0] - path = "" - elif pathformat == "posix": - import posixpath - - path = posixpath.normpath(path) - else: - raise TypeError(f"Invalid path format {pathformat!r}") - - if host in ("127.0.0.1", "::1", "localhost"): - host = None - - return host, path - - def _split_netloc(self) -> t.Tuple[t.Optional[str], str]: - if self._at in self.netloc: - auth, _, netloc = self.netloc.partition(self._at) - return auth, netloc - return None, self.netloc - - def _split_auth(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: - auth = self._split_netloc()[0] - if not auth: - return None, None - if self._colon not in auth: - return auth, None - - username, _, password = auth.partition(self._colon) - return username, password - - def _split_host(self) -> t.Tuple[t.Optional[str], t.Optional[str]]: - rv = self._split_netloc()[1] - if not rv: - return None, None - - if not rv.startswith(self._lbracket): - if self._colon in rv: - host, _, port = rv.partition(self._colon) - return host, port - return rv, None - - idx = rv.find(self._rbracket) - if idx < 0: - return rv, None - - host = rv[1:idx] - rest = rv[idx + 1 :] - if rest.startswith(self._colon): - return host, rest[1:] - return host, None - - -class URL(BaseURL): - """Represents a parsed URL. This behaves like a regular tuple but - also has some extra attributes that give further insight into the - URL. - """ - - __slots__ = () - _at = "@" - _colon = ":" - _lbracket = "[" - _rbracket = "]" - - def encode(self, charset: str = "utf-8", errors: str = "replace") -> "BytesURL": - """Encodes the URL to a tuple made out of bytes. The charset is - only being used for the path, query and fragment. 
- """ - return BytesURL( - self.scheme.encode("ascii"), # type: ignore - self.encode_netloc(), - self.path.encode(charset, errors), # type: ignore - self.query.encode(charset, errors), # type: ignore - self.fragment.encode(charset, errors), # type: ignore - ) - - -class BytesURL(BaseURL): - """Represents a parsed URL in bytes.""" - - __slots__ = () - _at = b"@" # type: ignore - _colon = b":" # type: ignore - _lbracket = b"[" # type: ignore - _rbracket = b"]" # type: ignore - - def __str__(self) -> str: - return self.to_url().decode("utf-8", "replace") # type: ignore - - def encode_netloc(self) -> bytes: # type: ignore - """Returns the netloc unchanged as bytes.""" - return self.netloc # type: ignore - - def decode(self, charset: str = "utf-8", errors: str = "replace") -> "URL": - """Decodes the URL to a tuple made out of strings. The charset is - only being used for the path, query and fragment. - """ - return URL( - self.scheme.decode("ascii"), # type: ignore - self.decode_netloc(), - self.path.decode(charset, errors), # type: ignore - self.query.decode(charset, errors), # type: ignore - self.fragment.decode(charset, errors), # type: ignore - ) - - -_unquote_maps: t.Dict[t.FrozenSet[int], t.Dict[bytes, int]] = {frozenset(): _hextobyte} - - -def _unquote_to_bytes( - string: t.Union[str, bytes], unsafe: t.Union[str, bytes] = "" -) -> bytes: - if isinstance(string, str): - string = string.encode("utf-8") - - if isinstance(unsafe, str): - unsafe = unsafe.encode("utf-8") - - unsafe = frozenset(bytearray(unsafe)) - groups = iter(string.split(b"%")) - result = bytearray(next(groups, b"")) - - try: - hex_to_byte = _unquote_maps[unsafe] - except KeyError: - hex_to_byte = _unquote_maps[unsafe] = { - h: b for h, b in _hextobyte.items() if b not in unsafe - } - - for group in groups: - code = group[:2] - - if code in hex_to_byte: - result.append(hex_to_byte[code]) - result.extend(group[2:]) - else: - result.append(37) # % - result.extend(group) - - return bytes(result) - - 
-def _url_encode_impl( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - charset: str, - sort: bool, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]], -) -> t.Iterator[str]: - from .datastructures import iter_multi_items - - iterable: t.Iterable[t.Tuple[str, str]] = iter_multi_items(obj) - - if sort: - iterable = sorted(iterable, key=key) - - for key_str, value_str in iterable: - if value_str is None: - continue - - if not isinstance(key_str, bytes): - key_bytes = str(key_str).encode(charset) - else: - key_bytes = key_str - - if not isinstance(value_str, bytes): - value_bytes = str(value_str).encode(charset) - else: - value_bytes = value_str - - yield f"{_fast_url_quote_plus(key_bytes)}={_fast_url_quote_plus(value_bytes)}" - - -def _url_unquote_legacy(value: str, unsafe: str = "") -> str: - try: - return url_unquote(value, charset="utf-8", errors="strict", unsafe=unsafe) - except UnicodeError: - return url_unquote(value, charset="latin1", unsafe=unsafe) - - -def url_parse( - url: str, scheme: t.Optional[str] = None, allow_fragments: bool = True -) -> BaseURL: - """Parses a URL from a string into a :class:`URL` tuple. If the URL - is lacking a scheme it can be provided as second argument. Otherwise, - it is ignored. Optionally fragments can be stripped from the URL - by setting `allow_fragments` to `False`. - - The inverse of this function is :func:`url_unparse`. - - :param url: the URL to parse. - :param scheme: the default schema to use if the URL is schemaless. - :param allow_fragments: if set to `False` a fragment will be removed - from the URL. 
- """ - s = _make_encode_wrapper(url) - is_text_based = isinstance(url, str) - - if scheme is None: - scheme = s("") - netloc = query = fragment = s("") - i = url.find(s(":")) - if i > 0 and _scheme_re.match(_to_str(url[:i], errors="replace")): - # make sure "iri" is not actually a port number (in which case - # "scheme" is really part of the path) - rest = url[i + 1 :] - if not rest or any(c not in s("0123456789") for c in rest): - # not a port number - scheme, url = url[:i].lower(), rest - - if url[:2] == s("//"): - delim = len(url) - for c in s("/?#"): - wdelim = url.find(c, 2) - if wdelim >= 0: - delim = min(delim, wdelim) - netloc, url = url[2:delim], url[delim:] - if (s("[") in netloc and s("]") not in netloc) or ( - s("]") in netloc and s("[") not in netloc - ): - raise ValueError("Invalid IPv6 URL") - - if allow_fragments and s("#") in url: - url, fragment = url.split(s("#"), 1) - if s("?") in url: - url, query = url.split(s("?"), 1) - - result_type = URL if is_text_based else BytesURL - return result_type(scheme, netloc, url, query, fragment) - - -def _make_fast_url_quote( - charset: str = "utf-8", - errors: str = "strict", - safe: t.Union[str, bytes] = "/:", - unsafe: t.Union[str, bytes] = "", -) -> t.Callable[[bytes], str]: - """Precompile the translation table for a URL encoding function. - - Unlike :func:`url_quote`, the generated function only takes the - string to quote. - - :param charset: The charset to encode the result with. - :param errors: How to handle encoding errors. - :param safe: An optional sequence of safe characters to never encode. - :param unsafe: An optional sequence of unsafe characters to always encode. 
- """ - if isinstance(safe, str): - safe = safe.encode(charset, errors) - - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - table = [chr(c) if c in safe else f"%{c:02X}" for c in range(256)] - - def quote(string: bytes) -> str: - return "".join([table[c] for c in string]) - - return quote - - -_fast_url_quote = _make_fast_url_quote() -_fast_quote_plus = _make_fast_url_quote(safe=" ", unsafe="+") - - -def _fast_url_quote_plus(string: bytes) -> str: - return _fast_quote_plus(string).replace(" ", "+") - - -def url_quote( - string: t.Union[str, bytes], - charset: str = "utf-8", - errors: str = "strict", - safe: t.Union[str, bytes] = "/:", - unsafe: t.Union[str, bytes] = "", -) -> str: - """URL encode a single string with a given encoding. - - :param s: the string to quote. - :param charset: the charset to be used. - :param safe: an optional sequence of safe characters. - :param unsafe: an optional sequence of unsafe characters. - - .. versionadded:: 0.9.2 - The `unsafe` parameter was added. - """ - if not isinstance(string, (str, bytes, bytearray)): - string = str(string) - if isinstance(string, str): - string = string.encode(charset, errors) - if isinstance(safe, str): - safe = safe.encode(charset, errors) - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - safe = (frozenset(bytearray(safe)) | _always_safe) - frozenset(bytearray(unsafe)) - rv = bytearray() - for char in bytearray(string): - if char in safe: - rv.append(char) - else: - rv.extend(_bytetohex[char]) - return bytes(rv).decode(charset) - - -def url_quote_plus( - string: str, charset: str = "utf-8", errors: str = "strict", safe: str = "" -) -> str: - """URL encode a single string with the given encoding and convert - whitespace to "+". - - :param s: The string to quote. - :param charset: The charset to be used. - :param safe: An optional sequence of safe characters. 
- """ - return url_quote(string, charset, errors, safe + " ", "+").replace(" ", "+") - - -def url_unparse(components: t.Tuple[str, str, str, str, str]) -> str: - """The reverse operation to :meth:`url_parse`. This accepts arbitrary - as well as :class:`URL` tuples and returns a URL as a string. - - :param components: the parsed URL as tuple which should be converted - into a URL string. - """ - _check_str_tuple(components) - scheme, netloc, path, query, fragment = components - s = _make_encode_wrapper(scheme) - url = s("") - - # We generally treat file:///x and file:/x the same which is also - # what browsers seem to do. This also allows us to ignore a schema - # register for netloc utilization or having to differentiate between - # empty and missing netloc. - if netloc or (scheme and path.startswith(s("/"))): - if path and path[:1] != s("/"): - path = s("/") + path - url = s("//") + (netloc or s("")) + path - elif path: - url += path - if scheme: - url = scheme + s(":") + url - if query: - url = url + s("?") + query - if fragment: - url = url + s("#") + fragment - return url - - -def url_unquote( - s: t.Union[str, bytes], - charset: str = "utf-8", - errors: str = "replace", - unsafe: str = "", -) -> str: - """URL decode a single string with a given encoding. If the charset - is set to `None` no decoding is performed and raw bytes are - returned. - - :param s: the string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: the error handling for the charset decoding. - """ - rv = _unquote_to_bytes(s, unsafe) - if charset is None: - return rv - return rv.decode(charset, errors) - - -def url_unquote_plus( - s: t.Union[str, bytes], charset: str = "utf-8", errors: str = "replace" -) -> str: - """URL decode a single string with the given `charset` and decode "+" to - whitespace. - - Per default encoding errors are ignored. 
If you want a different behavior - you can set `errors` to ``'replace'`` or ``'strict'``. - - :param s: The string to unquote. - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param errors: The error handling for the `charset` decoding. - """ - if isinstance(s, str): - s = s.replace("+", " ") - else: - s = s.replace(b"+", b" ") - return url_unquote(s, charset, errors) - - -def url_fix(s: str, charset: str = "utf-8") -> str: - r"""Sometimes you get an URL by a user that just isn't a real URL because - it contains unsafe characters like ' ' and so on. This function can fix - some of the problems in a similar way browsers handle data entered by the - user: - - >>> url_fix('http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)') - 'http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)' - - :param s: the string with the URL to fix. - :param charset: The target charset for the URL if the url was given - as a string. - """ - # First step is to switch to text processing and to convert - # backslashes (which are invalid in URLs anyways) to slashes. This is - # consistent with what Chrome does. 
- s = _to_str(s, charset, "replace").replace("\\", "/") - - # For the specific case that we look like a malformed windows URL - # we want to fix this up manually: - if s.startswith("file://") and s[7:8].isalpha() and s[8:10] in (":/", "|/"): - s = f"file:///{s[7:]}" - - url = url_parse(s) - path = url_quote(url.path, charset, safe="/%+$!*'(),") - qs = url_quote_plus(url.query, charset, safe=":&%=+$!*'(),") - anchor = url_quote_plus(url.fragment, charset, safe=":&%=+$!*'(),") - return url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor)) - - -# not-unreserved characters remain quoted when unquoting to IRI -_to_iri_unsafe = "".join([chr(c) for c in range(128) if c not in _always_safe]) - - -def _codec_error_url_quote(e: UnicodeError) -> t.Tuple[str, int]: +def _codec_error_url_quote(e: UnicodeError) -> tuple[str, int]: """Used in :func:`uri_to_iri` after unquoting to re-quote any invalid bytes. """ # the docs state that UnicodeError does have these attributes, # but mypy isn't picking them up - out = _fast_url_quote(e.object[e.start : e.end]) # type: ignore + out = quote(e.object[e.start : e.end], safe="") # type: ignore return out, e.end # type: ignore codecs.register_error("werkzeug.url_quote", _codec_error_url_quote) -def uri_to_iri( - uri: t.Union[str, t.Tuple[str, str, str, str, str]], - charset: str = "utf-8", - errors: str = "werkzeug.url_quote", -) -> str: +def _make_unquote_part(name: str, chars: str) -> t.Callable[[str], str]: + """Create a function that unquotes all percent encoded characters except those + given. This allows working with unquoted characters if possible while not changing + the meaning of a given part of a URL. 
+ """ + choices = "|".join(f"{ord(c):02X}" for c in sorted(chars)) + pattern = re.compile(f"((?:%(?:{choices}))+)", re.I) + + def _unquote_partial(value: str) -> str: + parts = iter(pattern.split(value)) + out = [] + + for part in parts: + out.append(unquote(part, "utf-8", "werkzeug.url_quote")) + out.append(next(parts, "")) + + return "".join(out) + + _unquote_partial.__name__ = f"_unquote_{name}" + return _unquote_partial + + +# characters that should remain quoted in URL parts +# based on https://url.spec.whatwg.org/#percent-encoded-bytes +# always keep all controls, space, and % quoted +_always_unsafe = bytes((*range(0x21), 0x25, 0x7F)).decode() +_unquote_fragment = _make_unquote_part("fragment", _always_unsafe) +_unquote_query = _make_unquote_part("query", _always_unsafe + "&=+#") +_unquote_path = _make_unquote_part("path", _always_unsafe + "/?#") +_unquote_user = _make_unquote_part("user", _always_unsafe + ":@/?#") + + +def uri_to_iri(uri: str) -> str: """Convert a URI to an IRI. All valid UTF-8 characters are unquoted, leaving all reserved and invalid characters quoted. If the URL has a domain, it is decoded from Punycode. @@ -725,9 +66,13 @@ def uri_to_iri( 'http://\\u2603.net/p\\xe5th?q=\\xe8ry%DF' :param uri: The URI to convert. - :param charset: The encoding to encode unquoted bytes with. - :param errors: Error handler to use during ``bytes.encode``. By - default, invalid bytes are left quoted. + + .. versionchanged:: 3.0 + Passing a tuple or bytes, and the ``charset`` and ``errors`` parameters, + are removed. + + .. versionchanged:: 2.3 + Which characters remain quoted is specific to each part of the URL. .. versionchanged:: 0.15 All reserved and invalid characters remain quoted. Previously, @@ -736,26 +81,35 @@ def uri_to_iri( .. 
versionadded:: 0.6 """ - if isinstance(uri, tuple): - uri = url_unparse(uri) + parts = urlsplit(uri) + path = _unquote_path(parts.path) + query = _unquote_query(parts.query) + fragment = _unquote_fragment(parts.fragment) - uri = url_parse(_to_str(uri, charset)) - path = url_unquote(uri.path, charset, errors, _to_iri_unsafe) - query = url_unquote(uri.query, charset, errors, _to_iri_unsafe) - fragment = url_unquote(uri.fragment, charset, errors, _to_iri_unsafe) - return url_unparse((uri.scheme, uri.decode_netloc(), path, query, fragment)) + if parts.hostname: + netloc = _decode_idna(parts.hostname) + else: + netloc = "" + + if ":" in netloc: + netloc = f"[{netloc}]" + + if parts.port: + netloc = f"{netloc}:{parts.port}" + + if parts.username: + auth = _unquote_user(parts.username) + + if parts.password: + password = _unquote_user(parts.password) + auth = f"{auth}:{password}" + + netloc = f"{auth}@{netloc}" + + return urlunsplit((parts.scheme, netloc, path, query, fragment)) -# reserved characters remain unquoted when quoting to URI -_to_uri_safe = ":/?#[]@!$&'()*+,;=%" - - -def iri_to_uri( - iri: t.Union[str, t.Tuple[str, str, str, str, str]], - charset: str = "utf-8", - errors: str = "strict", - safe_conversion: bool = False, -) -> str: +def iri_to_uri(iri: str) -> str: """Convert an IRI to a URI. All non-ASCII and unsafe characters are quoted. If the URL has a domain, it is encoded to Punycode. @@ -763,305 +117,100 @@ def iri_to_uri( 'http://xn--n3h.net/p%C3%A5th?q=%C3%A8ry%DF' :param iri: The IRI to convert. - :param charset: The encoding of the IRI. - :param errors: Error handler to use during ``bytes.encode``. - :param safe_conversion: Return the URL unchanged if it only contains - ASCII characters and no whitespace. See the explanation below. - There is a general problem with IRI conversion with some protocols - that are in violation of the URI specification. Consider the - following two IRIs:: + .. 
versionchanged:: 3.0 + Passing a tuple or bytes, the ``charset`` and ``errors`` parameters, + and the ``safe_conversion`` parameter, are removed. - magnet:?xt=uri:whatever - itms-services://?action=download-manifest - - After parsing, we don't know if the scheme requires the ``//``, - which is dropped if empty, but conveys different meanings in the - final URL if it's present or not. In this case, you can use - ``safe_conversion``, which will return the URL unchanged if it only - contains ASCII characters and no whitespace. This can result in a - URI with unquoted characters if it was not already quoted correctly, - but preserves the URL's semantics. Werkzeug uses this for the - ``Location`` header for redirects. + .. versionchanged:: 2.3 + Which characters remain unquoted is specific to each part of the URL. .. versionchanged:: 0.15 - All reserved characters remain unquoted. Previously, only some - reserved characters were left unquoted. + All reserved characters remain unquoted. Previously, only some reserved + characters were left unquoted. .. versionchanged:: 0.9.6 The ``safe_conversion`` parameter was added. .. versionadded:: 0.6 """ - if isinstance(iri, tuple): - iri = url_unparse(iri) + parts = urlsplit(iri) + # safe = https://url.spec.whatwg.org/#url-path-segment-string + # as well as percent for things that are already quoted + path = quote(parts.path, safe="%!$&'()*+,/:;=@") + query = quote(parts.query, safe="%!$&'()*+,/:;=?@") + fragment = quote(parts.fragment, safe="%!#$&'()*+,/:;=?@") - if safe_conversion: - # If we're not sure if it's safe to convert the URL, and it only - # contains ASCII characters, return it unconverted. - try: - native_iri = _to_str(iri) - ascii_iri = native_iri.encode("ascii") - - # Only return if it doesn't have whitespace. (Why?) 
- if len(ascii_iri.split()) == 1: - return native_iri - except UnicodeError: - pass - - iri = url_parse(_to_str(iri, charset, errors)) - path = url_quote(iri.path, charset, errors, _to_uri_safe) - query = url_quote(iri.query, charset, errors, _to_uri_safe) - fragment = url_quote(iri.fragment, charset, errors, _to_uri_safe) - return url_unparse((iri.scheme, iri.encode_netloc(), path, query, fragment)) - - -def url_decode( - s: t.AnyStr, - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: str = "&", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, -) -> "ds.MultiDict[str, str]": - """Parse a query string and return it as a :class:`MultiDict`. - - :param s: The query string to parse. - :param charset: Decode bytes to string with this charset. If not - given, bytes are returned as-is. - :param include_empty: Include keys with empty values in the dict. - :param errors: Error handling behavior when decoding bytes. - :param separator: Separator character between pairs. - :param cls: Container to hold result instead of :class:`MultiDict`. - - .. versionchanged:: 2.0 - The ``decode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. - - .. versionchanged:: 0.5 - In previous versions ";" and "&" could be used for url decoding. - Now only "&" is supported. If you want to use ";", a different - ``separator`` can be provided. - - .. versionchanged:: 0.5 - The ``cls`` parameter was added. 
- """ - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - if isinstance(s, str) and not isinstance(separator, str): - separator = separator.decode(charset or "ascii") - elif isinstance(s, bytes) and not isinstance(separator, bytes): - separator = separator.encode(charset or "ascii") # type: ignore - return cls( - _url_decode_impl( - s.split(separator), charset, include_empty, errors # type: ignore - ) - ) - - -def url_decode_stream( - stream: t.IO[bytes], - charset: str = "utf-8", - include_empty: bool = True, - errors: str = "replace", - separator: bytes = b"&", - cls: t.Optional[t.Type["ds.MultiDict"]] = None, - limit: t.Optional[int] = None, -) -> "ds.MultiDict[str, str]": - """Works like :func:`url_decode` but decodes a stream. The behavior - of stream and limit follows functions like - :func:`~werkzeug.wsgi.make_line_iter`. The generator of pairs is - directly fed to the `cls` so you can consume the data while it's - parsed. - - :param stream: a stream with the encoded querystring - :param charset: the charset of the query string. If set to `None` - no decoding will take place. - :param include_empty: Set to `False` if you don't want empty values to - appear in the dict. - :param errors: the decoding error behavior. - :param separator: the pair separator to be used, defaults to ``&`` - :param cls: an optional dict class to use. If this is not specified - or `None` the default :class:`MultiDict` is used. - :param limit: the content length of the URL data. Not necessary if - a limited stream is provided. - - .. versionchanged:: 2.0 - The ``decode_keys`` and ``return_iterator`` parameters are - deprecated and will be removed in Werkzeug 2.1. - - .. 
versionadded:: 0.8 - """ - from .wsgi import make_chunk_iter - - pair_iter = make_chunk_iter(stream, separator, limit) - decoder = _url_decode_impl(pair_iter, charset, include_empty, errors) - - if cls is None: - from .datastructures import MultiDict # noqa: F811 - - cls = MultiDict - - return cls(decoder) - - -def _url_decode_impl( - pair_iter: t.Iterable[t.AnyStr], charset: str, include_empty: bool, errors: str -) -> t.Iterator[t.Tuple[str, str]]: - for pair in pair_iter: - if not pair: - continue - s = _make_encode_wrapper(pair) - equal = s("=") - if equal in pair: - key, value = pair.split(equal, 1) - else: - if not include_empty: - continue - key = pair - value = s("") - yield ( - url_unquote_plus(key, charset, errors), - url_unquote_plus(value, charset, errors), - ) - - -def url_encode( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - charset: str = "utf-8", - sort: bool = False, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, - separator: str = "&", -) -> str: - """URL encode a dict/`MultiDict`. If a value is `None` it will not appear - in the result string. Per default only values are encoded into the target - charset strings. - - :param obj: the object to encode into a query string. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. - - .. versionchanged:: 2.0 - The ``encode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. - - .. versionchanged:: 0.5 - Added the ``sort``, ``key``, and ``separator`` parameters. 
- """ - separator = _to_str(separator, "ascii") - return separator.join(_url_encode_impl(obj, charset, sort, key)) - - -def url_encode_stream( - obj: t.Union[t.Mapping[str, str], t.Iterable[t.Tuple[str, str]]], - stream: t.Optional[t.IO[str]] = None, - charset: str = "utf-8", - sort: bool = False, - key: t.Optional[t.Callable[[t.Tuple[str, str]], t.Any]] = None, - separator: str = "&", -) -> None: - """Like :meth:`url_encode` but writes the results to a stream - object. If the stream is `None` a generator over all encoded - pairs is returned. - - :param obj: the object to encode into a query string. - :param stream: a stream to write the encoded object into or `None` if - an iterator over the encoded pairs should be returned. In - that case the separator argument is ignored. - :param charset: the charset of the query string. - :param sort: set to `True` if you want parameters to be sorted by `key`. - :param separator: the separator to be used for the pairs. - :param key: an optional function to be used for sorting. For more details - check out the :func:`sorted` documentation. - - .. versionchanged:: 2.0 - The ``encode_keys`` parameter is deprecated and will be removed - in Werkzeug 2.1. - - .. versionadded:: 0.8 - """ - separator = _to_str(separator, "ascii") - gen = _url_encode_impl(obj, charset, sort, key) - if stream is None: - return gen # type: ignore - for idx, chunk in enumerate(gen): - if idx: - stream.write(separator) - stream.write(chunk) - return None - - -def url_join( - base: t.Union[str, t.Tuple[str, str, str, str, str]], - url: t.Union[str, t.Tuple[str, str, str, str, str]], - allow_fragments: bool = True, -) -> str: - """Join a base URL and a possibly relative URL to form an absolute - interpretation of the latter. - - :param base: the base URL for the join operation. - :param url: the URL to join. - :param allow_fragments: indicates whether fragments should be allowed. 
- """ - if isinstance(base, tuple): - base = url_unparse(base) - if isinstance(url, tuple): - url = url_unparse(url) - - _check_str_tuple((base, url)) - s = _make_encode_wrapper(base) - - if not base: - return url - if not url: - return base - - bscheme, bnetloc, bpath, bquery, bfragment = url_parse( - base, allow_fragments=allow_fragments - ) - scheme, netloc, path, query, fragment = url_parse(url, bscheme, allow_fragments) - if scheme != bscheme: - return url - if netloc: - return url_unparse((scheme, netloc, path, query, fragment)) - netloc = bnetloc - - if path[:1] == s("/"): - segments = path.split(s("/")) - elif not path: - segments = bpath.split(s("/")) - if not query: - query = bquery + if parts.hostname: + netloc = parts.hostname.encode("idna").decode("ascii") else: - segments = bpath.split(s("/"))[:-1] + path.split(s("/")) + netloc = "" - # If the rightmost part is "./" we want to keep the slash but - # remove the dot. - if segments[-1] == s("."): - segments[-1] = s("") + if ":" in netloc: + netloc = f"[{netloc}]" - # Resolve ".." and "." - segments = [segment for segment in segments if segment != s(".")] - while True: - i = 1 - n = len(segments) - 1 - while i < n: - if segments[i] == s("..") and segments[i - 1] not in (s(""), s("..")): - del segments[i - 1 : i + 1] - break - i += 1 - else: - break + if parts.port: + netloc = f"{netloc}:{parts.port}" - # Remove trailing ".." 
if the URL is absolute - unwanted_marker = [s(""), s("..")] - while segments[:2] == unwanted_marker: - del segments[1] + if parts.username: + auth = quote(parts.username, safe="%!$&'()*+,;=") - path = s("/").join(segments) - return url_unparse((scheme, netloc, path, query, fragment)) + if parts.password: + password = quote(parts.password, safe="%!$&'()*+,;=") + auth = f"{auth}:{password}" + + netloc = f"{auth}@{netloc}" + + return urlunsplit((parts.scheme, netloc, path, query, fragment)) + + +def _invalid_iri_to_uri(iri: str) -> str: + """The URL scheme ``itms-services://`` must contain the ``//`` even though it does + not have a host component. There may be other invalid schemes as well. Currently, + responses will always call ``iri_to_uri`` on the redirect ``Location`` header, which + removes the ``//``. For now, if the IRI only contains ASCII and does not contain + spaces, pass it on as-is. In Werkzeug 3.0, this should become a + ``response.process_location`` flag. + + :meta private: + """ + try: + iri.encode("ascii") + except UnicodeError: + pass + else: + if len(iri.split(None, 1)) == 1: + return iri + + return iri_to_uri(iri) + + +def _decode_idna(domain: str) -> str: + try: + data = domain.encode("ascii") + except UnicodeEncodeError: + # If the domain is not ASCII, it's decoded already. + return domain + + try: + # Try decoding in one shot. + return data.decode("idna") + except UnicodeDecodeError: + pass + + # Decode each part separately, leaving invalid parts as punycode. 
+ parts = [] + + for part in data.split(b"."): + try: + parts.append(part.decode("idna")) + except UnicodeDecodeError: + parts.append(part.decode("ascii")) + + return ".".join(parts) + + +def _urlencode(query: t.Mapping[str, str] | t.Iterable[tuple[str, str]]) -> str: + items = [x for x in iter_multi_items(query) if x[1] is not None] + # safe = https://url.spec.whatwg.org/#percent-encoded-bytes + return urlencode(items, safe="!$'()*,/:;?@") diff --git a/src/werkzeug/user_agent.py b/src/werkzeug/user_agent.py index 66ffcbe..17e5d3f 100644 --- a/src/werkzeug/user_agent.py +++ b/src/werkzeug/user_agent.py @@ -1,4 +1,4 @@ -import typing as t +from __future__ import annotations class UserAgent: @@ -17,16 +17,16 @@ class UserAgent: provide a built-in parser. """ - platform: t.Optional[str] = None + platform: str | None = None """The OS name, if it could be parsed from the string.""" - browser: t.Optional[str] = None + browser: str | None = None """The browser name, if it could be parsed from the string.""" - version: t.Optional[str] = None + version: str | None = None """The browser version, if it could be parsed from the string.""" - language: t.Optional[str] = None + language: str | None = None """The browser language, if it could be parsed from the string.""" def __init__(self, string: str) -> None: diff --git a/src/werkzeug/utils.py b/src/werkzeug/utils.py index 672e6e5..785ac28 100644 --- a/src/werkzeug/utils.py +++ b/src/werkzeug/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import mimetypes import os @@ -8,6 +10,7 @@ import typing as t import unicodedata from datetime import datetime from time import time +from urllib.parse import quote from zlib import adler32 from markupsafe import escape @@ -19,7 +22,6 @@ from .datastructures import Headers from .exceptions import NotFound from .exceptions import RequestedRangeNotSatisfiable from .security import safe_join -from .urls import url_quote from .wsgi import wrap_file if t.TYPE_CHECKING: @@ 
-31,19 +33,14 @@ _T = t.TypeVar("_T") _entity_re = re.compile(r"&([^;]+);") _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.-]") -_windows_device_files = ( +_windows_device_files = { "CON", - "AUX", - "COM1", - "COM2", - "COM3", - "COM4", - "LPT1", - "LPT2", - "LPT3", "PRN", + "AUX", "NUL", -) + *(f"COM{i}" for i in range(10)), + *(f"LPT{i}" for i in range(10)), +} class cached_property(property, t.Generic[_T]): @@ -80,8 +77,8 @@ class cached_property(property, t.Generic[_T]): def __init__( self, fget: t.Callable[[t.Any], _T], - name: t.Optional[str] = None, - doc: t.Optional[str] = None, + name: str | None = None, + doc: str | None = None, ) -> None: super().__init__(fget, doc=doc) self.__name__ = name or fget.__name__ @@ -145,14 +142,14 @@ class environ_property(_DictAccessorProperty[_TAccessorValue]): read_only = True - def lookup(self, obj: "Request") -> "WSGIEnvironment": + def lookup(self, obj: Request) -> WSGIEnvironment: return obj.environ class header_property(_DictAccessorProperty[_TAccessorValue]): """Like `environ_property` but for headers.""" - def lookup(self, obj: t.Union["Request", "Response"]) -> Headers: + def lookup(self, obj: Request | Response) -> Headers: return obj.headers @@ -221,7 +218,7 @@ def secure_filename(filename: str) -> str: filename = unicodedata.normalize("NFKD", filename) filename = filename.encode("ascii", "ignore").decode("ascii") - for sep in os.path.sep, os.path.altsep: + for sep in os.sep, os.path.altsep: if sep: filename = filename.replace(sep, " ") filename = str(_filename_ascii_strip_re.sub("", "_".join(filename.split()))).strip( @@ -242,8 +239,8 @@ def secure_filename(filename: str) -> str: def redirect( - location: str, code: int = 302, Response: t.Optional[t.Type["Response"]] = None -) -> "Response": + location: str, code: int = 302, Response: type[Response] | None = None +) -> Response: """Returns a response object (a WSGI application) that, if called, redirects the client to the target location. 
Supported codes are 301, 302, 303, 305, 307, and 308. 300 is not supported because @@ -264,24 +261,16 @@ def redirect( unspecified. """ if Response is None: - from .wrappers import Response # type: ignore + from .wrappers import Response - display_location = escape(location) - if isinstance(location, str): - # Safe conversion is necessary here as we might redirect - # to a broken URI scheme (for instance itms-services). - from .urls import iri_to_uri - - location = iri_to_uri(location, safe_conversion=True) - - response = Response( # type: ignore + html_location = escape(location) + response = Response( # type: ignore[misc] "\n" "\n" "Redirecting...\n" "

Redirecting...

\n" "

You should be redirected automatically to the target URL: " - f'{display_location}. If' - " not, click the link.\n", + f'{html_location}. If not, click the link.\n', code, mimetype="text/html", ) @@ -289,7 +278,7 @@ def redirect( return response -def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Response": +def append_slash_redirect(environ: WSGIEnvironment, code: int = 308) -> Response: """Redirect to the current URL with a slash appended. If the current URL is ``/user/42``, the redirect URL will be @@ -327,21 +316,19 @@ def append_slash_redirect(environ: "WSGIEnvironment", code: int = 308) -> "Respo def send_file( - path_or_file: t.Union[os.PathLike, str, t.IO[bytes]], - environ: "WSGIEnvironment", - mimetype: t.Optional[str] = None, + path_or_file: os.PathLike | str | t.IO[bytes], + environ: WSGIEnvironment, + mimetype: str | None = None, as_attachment: bool = False, - download_name: t.Optional[str] = None, + download_name: str | None = None, conditional: bool = True, - etag: t.Union[bool, str] = True, - last_modified: t.Optional[t.Union[datetime, int, float]] = None, - max_age: t.Optional[ - t.Union[int, t.Callable[[t.Optional[str]], t.Optional[int]]] - ] = None, + etag: bool | str = True, + last_modified: datetime | int | float | None = None, + max_age: None | (int | t.Callable[[str | None], int | None]) = None, use_x_sendfile: bool = False, - response_class: t.Optional[t.Type["Response"]] = None, - _root_path: t.Optional[t.Union[os.PathLike, str]] = None, -) -> "Response": + response_class: type[Response] | None = None, + _root_path: os.PathLike | str | None = None, +) -> Response: """Send the contents of a file to the client. The first argument can be a file path or a file-like object. Paths @@ -352,7 +339,7 @@ def send_file( Never pass file paths provided by a user. The path is assumed to be trusted, so a user could craft a path to access a file you didn't - intend. + intend. 
Use :func:`send_from_directory` to safely serve user-provided paths. If the WSGI server sets a ``file_wrapper`` in ``environ``, it is used, otherwise Werkzeug's built-in wrapper is used. Alternatively, @@ -419,10 +406,10 @@ def send_file( response_class = Response - path: t.Optional[str] = None - file: t.Optional[t.IO[bytes]] = None - size: t.Optional[int] = None - mtime: t.Optional[float] = None + path: str | None = None + file: t.IO[bytes] | None = None + size: int | None = None + mtime: float | None = None headers = Headers() if isinstance(path_or_file, (os.PathLike, str)) or hasattr( @@ -470,7 +457,8 @@ def send_file( except UnicodeEncodeError: simple = unicodedata.normalize("NFKD", download_name) simple = simple.encode("ascii", "ignore").decode("ascii") - quoted = url_quote(download_name, safe="") + # safe = RFC 5987 attr-char + quoted = quote(download_name, safe="!#$&+-.^_`|~") names = {"filename": simple, "filename*": f"UTF-8''{quoted}"} else: names = {"filename": download_name} @@ -547,11 +535,11 @@ def send_file( def send_from_directory( - directory: t.Union[os.PathLike, str], - path: t.Union[os.PathLike, str], - environ: "WSGIEnvironment", + directory: os.PathLike | str, + path: os.PathLike | str, + environ: WSGIEnvironment, **kwargs: t.Any, -) -> "Response": +) -> Response: """Send a file from within a directory using :func:`send_file`. This is a secure way to serve files from a folder, such as static @@ -562,9 +550,10 @@ def send_from_directory( If the final path does not point to an existing regular file, returns a 404 :exc:`~werkzeug.exceptions.NotFound` error. - :param directory: The directory that ``path`` must be located under. - :param path: The path to the file to send, relative to - ``directory``. + :param directory: The directory that ``path`` must be located under. This *must not* + be a value provided by the client, otherwise it becomes insecure. + :param path: The path to the file to send, relative to ``directory``. 
This is the + part of the path provided by the client, which is checked for security. :param environ: The WSGI environ for the current request. :param kwargs: Arguments to pass to :func:`send_file`. @@ -581,12 +570,8 @@ def send_from_directory( if "_root_path" in kwargs: path = os.path.join(kwargs["_root_path"], path) - try: - if not os.path.isfile(path): - raise NotFound() - except ValueError: - # path contains null byte on Python < 3.8 - raise NotFound() from None + if not os.path.isfile(path): + raise NotFound() return send_file(path, environ, **kwargs) diff --git a/src/werkzeug/wrappers/request.py b/src/werkzeug/wrappers/request.py index 57b739c..25b0916 100644 --- a/src/werkzeug/wrappers/request.py +++ b/src/werkzeug/wrappers/request.py @@ -1,6 +1,7 @@ +from __future__ import annotations + import functools import json -import typing import typing as t from io import BytesIO @@ -11,6 +12,8 @@ from ..datastructures import FileStorage from ..datastructures import ImmutableMultiDict from ..datastructures import iter_multi_items from ..datastructures import MultiDict +from ..exceptions import BadRequest +from ..exceptions import UnsupportedMediaType from ..formparser import default_stream_factory from ..formparser import FormDataParser from ..sansio.request import Request as _SansIORequest @@ -18,10 +21,8 @@ from ..utils import cached_property from ..utils import environ_property from ..wsgi import _get_server from ..wsgi import get_input_stream -from werkzeug.exceptions import BadRequest if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment @@ -49,13 +50,19 @@ class Request(_SansIORequest): prevent consuming the form data in middleware, which would make it unavailable to the final application. + .. versionchanged:: 3.0 + The ``charset``, ``url_charset``, and ``encoding_errors`` parameters + were removed. + + .. 
versionchanged:: 2.1 + Old ``BaseRequest`` and mixin classes were removed. + .. versionchanged:: 2.1 Remove the ``disable_data_descriptor`` attribute. .. versionchanged:: 2.0 Combine ``BaseRequest`` and mixins into a single ``Request`` - class. Using the old classes is deprecated and will be removed - in Werkzeug 2.1. + class. .. versionchanged:: 0.5 Read-only mode is enforced with immutable classes for all data. @@ -67,10 +74,8 @@ class Request(_SansIORequest): #: parsing fails because more than the specified value is transmitted #: a :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised. #: - #: Have a look at :doc:`/request_data` for more details. - #: #: .. versionadded:: 0.5 - max_content_length: t.Optional[int] = None + max_content_length: int | None = None #: the maximum form field size. This is forwarded to the form data #: parsing function (:func:`parse_form_data`). When set and the @@ -78,18 +83,23 @@ class Request(_SansIORequest): #: data in memory for post data is longer than the specified value a #: :exc:`~werkzeug.exceptions.RequestEntityTooLarge` exception is raised. #: - #: Have a look at :doc:`/request_data` for more details. - #: #: .. versionadded:: 0.5 - max_form_memory_size: t.Optional[int] = None + max_form_memory_size: int | None = None + + #: The maximum number of multipart parts to parse, passed to + #: :attr:`form_data_parser_class`. Parsing form data with more than this + #: many parts will raise :exc:`~.RequestEntityTooLarge`. + #: + #: .. versionadded:: 2.2.3 + max_form_parts = 1000 #: The form data parser that should be used. Can be replaced to customize #: the form date parsing. - form_data_parser_class: t.Type[FormDataParser] = FormDataParser + form_data_parser_class: type[FormDataParser] = FormDataParser #: The WSGI environment containing HTTP headers and information from #: the WSGI server. - environ: "WSGIEnvironment" + environ: WSGIEnvironment #: Set when creating the request object. 
If ``True``, reading from #: the request body will cause a ``RuntimeException``. Useful to @@ -98,7 +108,7 @@ class Request(_SansIORequest): def __init__( self, - environ: "WSGIEnvironment", + environ: WSGIEnvironment, populate_request: bool = True, shallow: bool = False, ) -> None: @@ -106,12 +116,8 @@ class Request(_SansIORequest): method=environ.get("REQUEST_METHOD", "GET"), scheme=environ.get("wsgi.url_scheme", "http"), server=_get_server(environ), - root_path=_wsgi_decoding_dance( - environ.get("SCRIPT_NAME") or "", self.charset, self.encoding_errors - ), - path=_wsgi_decoding_dance( - environ.get("PATH_INFO") or "", self.charset, self.encoding_errors - ), + root_path=_wsgi_decoding_dance(environ.get("SCRIPT_NAME") or ""), + path=_wsgi_decoding_dance(environ.get("PATH_INFO") or ""), query_string=environ.get("QUERY_STRING", "").encode("latin1"), headers=EnvironHeaders(environ), remote_addr=environ.get("REMOTE_ADDR"), @@ -123,7 +129,7 @@ class Request(_SansIORequest): self.environ["werkzeug.request"] = self @classmethod - def from_values(cls, *args: t.Any, **kwargs: t.Any) -> "Request": + def from_values(cls, *args: t.Any, **kwargs: t.Any) -> Request: """Create a new request object based on the values provided. If environ is given missing values are filled from there. This method is useful for small scripts when you need to simulate a request from an URL. @@ -143,8 +149,6 @@ class Request(_SansIORequest): """ from ..test import EnvironBuilder - charset = kwargs.pop("charset", cls.charset) - kwargs["charset"] = charset builder = EnvironBuilder(*args, **kwargs) try: return builder.get_request(cls) @@ -152,9 +156,7 @@ class Request(_SansIORequest): builder.close() @classmethod - def application( - cls, f: t.Callable[["Request"], "WSGIApplication"] - ) -> "WSGIApplication": + def application(cls, f: t.Callable[[Request], WSGIApplication]) -> WSGIApplication: """Decorate a function as responder that accepts the request as the last argument. 
This works like the :func:`responder` decorator but the function is passed the request object as the @@ -193,10 +195,10 @@ class Request(_SansIORequest): def _get_file_stream( self, - total_content_length: t.Optional[int], - content_type: t.Optional[str], - filename: t.Optional[str] = None, - content_length: t.Optional[int] = None, + total_content_length: int | None, + content_type: str | None, + filename: str | None = None, + content_length: int | None = None, ) -> t.IO[bytes]: """Called to get a stream for the file upload. @@ -240,12 +242,11 @@ class Request(_SansIORequest): .. versionadded:: 0.8 """ return self.form_data_parser_class( - self._get_file_stream, - self.charset, - self.encoding_errors, - self.max_form_memory_size, - self.max_content_length, - self.parameter_storage_class, + stream_factory=self._get_file_stream, + max_form_memory_size=self.max_form_memory_size, + max_content_length=self.max_content_length, + max_form_parts=self.max_form_parts, + cls=self.parameter_storage_class, ) def _load_form_data(self) -> None: @@ -304,7 +305,7 @@ class Request(_SansIORequest): for _key, value in iter_multi_items(files or ()): value.close() - def __enter__(self) -> "Request": + def __enter__(self) -> Request: return self def __exit__(self, exc_type, exc_value, tb) -> None: # type: ignore @@ -312,21 +313,30 @@ class Request(_SansIORequest): @cached_property def stream(self) -> t.IO[bytes]: - """ - If the incoming form data was not encoded with a known mimetype - the data is stored unmodified in this stream for consumption. Most - of the time it is a better idea to use :attr:`data` which will give - you that data as a string. The stream only returns the data once. + """The WSGI input stream, with safety checks. This stream can only be consumed + once. - Unlike :attr:`input_stream` this stream is properly guarded that you - can't accidentally read past the length of the input. 
Werkzeug will - internally always refer to this stream to read data which makes it - possible to wrap this object with a stream that does filtering. + Use :meth:`get_data` to get the full data as bytes or text. The :attr:`data` + attribute will contain the full bytes only if they do not represent form data. + The :attr:`form` attribute will contain the parsed form data in that case. + + Unlike :attr:`input_stream`, this stream guards against infinite streams or + reading past :attr:`content_length` or :attr:`max_content_length`. + + If ``max_content_length`` is set, it can be enforced on streams if + ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned. + + If the limit is reached before the underlying stream is exhausted (such as a + file that is too large, or an infinite stream), the remaining contents of the + stream cannot be read safely. Depending on how the server handles this, clients + may show a "connection reset" failure instead of seeing the 413 response. + + .. versionchanged:: 2.3 + Check ``max_content_length`` preemptively and while reading. .. versionchanged:: 0.9 - This stream is now always available but might be consumed by the - form parser later on. Previously the stream was only set if no - parsing happened. + The stream is always set (but may be consumed) even if form parsing was + accessed first. """ if self.shallow: raise RuntimeError( @@ -334,46 +344,51 @@ class Request(_SansIORequest): " from the input stream is disabled." ) - return get_input_stream(self.environ) + return get_input_stream( + self.environ, max_content_length=self.max_content_length + ) input_stream = environ_property[t.IO[bytes]]( "wsgi.input", - doc="""The WSGI input stream. + doc="""The raw WSGI input stream, without any safety checks. - In general it's a bad idea to use this one because you can - easily read past the boundary. Use the :attr:`stream` - instead.""", + This is dangerous to use. 
It does not guard against infinite streams or reading + past :attr:`content_length` or :attr:`max_content_length`. + + Use :attr:`stream` instead. + """, ) @cached_property def data(self) -> bytes: - """ - Contains the incoming request data as string in case it came with - a mimetype Werkzeug does not handle. + """The raw data read from :attr:`stream`. Will be empty if the request + represents form data. + + To get the raw data even if it represents form data, use :meth:`get_data`. """ return self.get_data(parse_form_data=True) - @typing.overload + @t.overload def get_data( # type: ignore self, cache: bool = True, - as_text: "te.Literal[False]" = False, + as_text: t.Literal[False] = False, parse_form_data: bool = False, ) -> bytes: ... - @typing.overload + @t.overload def get_data( self, cache: bool = True, - as_text: "te.Literal[True]" = ..., + as_text: t.Literal[True] = ..., parse_form_data: bool = False, ) -> str: ... def get_data( self, cache: bool = True, as_text: bool = False, parse_form_data: bool = False - ) -> t.Union[bytes, str]: + ) -> bytes | str: """This reads the buffered incoming data from the client into one bytes object. By default this is cached but that behavior can be changed by setting `cache` to `False`. @@ -406,11 +421,11 @@ class Request(_SansIORequest): if cache: self._cached_data = rv if as_text: - rv = rv.decode(self.charset, self.encoding_errors) + rv = rv.decode(errors="replace") return rv @cached_property - def form(self) -> "ImmutableMultiDict[str, str]": + def form(self) -> ImmutableMultiDict[str, str]: """The form parameters. By default an :class:`~werkzeug.datastructures.ImmutableMultiDict` is returned from this function. 
This can be changed by setting @@ -429,7 +444,7 @@ class Request(_SansIORequest): return self.form @cached_property - def values(self) -> "CombinedMultiDict[str, str]": + def values(self) -> CombinedMultiDict[str, str]: """A :class:`werkzeug.datastructures.CombinedMultiDict` that combines :attr:`args` and :attr:`form`. @@ -458,7 +473,7 @@ class Request(_SansIORequest): return CombinedMultiDict(args) @cached_property - def files(self) -> "ImmutableMultiDict[str, FileStorage]": + def files(self) -> ImmutableMultiDict[str, FileStorage]: """:class:`~werkzeug.datastructures.MultiDict` object containing all uploaded files. Each key in :attr:`files` is the name from the ````. Each value in :attr:`files` is a @@ -525,14 +540,17 @@ class Request(_SansIORequest): json_module = json @property - def json(self) -> t.Optional[t.Any]: + def json(self) -> t.Any | None: """The parsed JSON data if :attr:`mimetype` indicates JSON (:mimetype:`application/json`, see :attr:`is_json`). Calls :meth:`get_json` with default arguments. If the request content type is not ``application/json``, this - will raise a 400 Bad Request error. + will raise a 415 Unsupported Media Type error. + + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. .. versionchanged:: 2.1 Raise a 400 error if the content type is incorrect. @@ -541,18 +559,30 @@ class Request(_SansIORequest): # Cached values for ``(silent=False, silent=True)``. Initialized # with sentinel values. - _cached_json: t.Tuple[t.Any, t.Any] = (Ellipsis, Ellipsis) + _cached_json: tuple[t.Any, t.Any] = (Ellipsis, Ellipsis) + + @t.overload + def get_json( + self, force: bool = ..., silent: t.Literal[False] = ..., cache: bool = ... + ) -> t.Any: + ... + + @t.overload + def get_json( + self, force: bool = ..., silent: bool = ..., cache: bool = ... + ) -> t.Any | None: + ... def get_json( self, force: bool = False, silent: bool = False, cache: bool = True - ) -> t.Optional[t.Any]: + ) -> t.Any | None: """Parse :attr:`data` as JSON. 
If the mimetype does not indicate JSON (:mimetype:`application/json`, see :attr:`is_json`), or parsing fails, :meth:`on_json_loading_failed` is called and its return value is used as the return value. By default this - raises a 400 Bad Request error. + raises a 415 Unsupported Media Type resp. :param force: Ignore the mimetype and always try to parse JSON. :param silent: Silence mimetype and parsing errors, and @@ -560,6 +590,9 @@ class Request(_SansIORequest): :param cache: Store the parsed JSON to return for subsequent calls. + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. + .. versionchanged:: 2.1 Raise a 400 error if the content type is incorrect. """ @@ -595,7 +628,7 @@ class Request(_SansIORequest): return rv - def on_json_loading_failed(self, e: t.Optional[ValueError]) -> t.Any: + def on_json_loading_failed(self, e: ValueError | None) -> t.Any: """Called if :meth:`get_json` fails and isn't silenced. If this method returns a value, it is used as the return value @@ -604,11 +637,14 @@ class Request(_SansIORequest): :param e: If parsing failed, this is the exception. It will be ``None`` if the content type wasn't ``application/json``. + + .. versionchanged:: 2.3 + Raise a 415 error instead of 400. """ if e is not None: raise BadRequest(f"Failed to decode JSON object: {e}") - raise BadRequest( + raise UnsupportedMediaType( "Did not attempt to load JSON data because the request" " Content-Type was not 'application/json'." 
) diff --git a/src/werkzeug/wrappers/response.py b/src/werkzeug/wrappers/response.py index 7e888cb..ee5c694 100644 --- a/src/werkzeug/wrappers/response.py +++ b/src/werkzeug/wrappers/response.py @@ -1,15 +1,15 @@ -import json -import typing -import typing as t -import warnings -from http import HTTPStatus +from __future__ import annotations + +import json +import typing as t +from http import HTTPStatus +from urllib.parse import urljoin -from .._internal import _to_bytes from ..datastructures import Headers from ..http import remove_entity_headers from ..sansio.response import Response as _SansIOResponse +from ..urls import _invalid_iri_to_uri from ..urls import iri_to_uri -from ..urls import url_join from ..utils import cached_property from ..wsgi import ClosingIterator from ..wsgi import get_current_url @@ -22,48 +22,20 @@ from werkzeug.http import parse_range_header from werkzeug.wsgi import _RangeWrapper if t.TYPE_CHECKING: - import typing_extensions as te from _typeshed.wsgi import StartResponse from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment from .request import Request -def _warn_if_string(iterable: t.Iterable) -> None: - """Helper for the response objects to check if the iterable returned - to the WSGI server is not a string. - """ - if isinstance(iterable, str): - warnings.warn( - "Response iterable was set to a string. This will appear to" - " work but means that the server will send the data to the" - " client one character at a time. 
This is almost never" - " intended behavior, use 'response.data' to assign strings" - " to the response object.", - stacklevel=2, - ) - - -def _iter_encoded( - iterable: t.Iterable[t.Union[str, bytes]], charset: str -) -> t.Iterator[bytes]: +def _iter_encoded(iterable: t.Iterable[str | bytes]) -> t.Iterator[bytes]: for item in iterable: if isinstance(item, str): - yield item.encode(charset) + yield item.encode() else: yield item -def _clean_accept_ranges(accept_ranges: t.Union[bool, str]) -> str: - if accept_ranges is True: - return "bytes" - elif accept_ranges is False: - return "none" - elif isinstance(accept_ranges, str): - return accept_ranges - raise ValueError("Invalid accept_ranges value") - - class Response(_SansIOResponse): """Represents an outgoing WSGI HTTP response with body, status, and headers. Has properties and methods for using the functionality @@ -123,10 +95,12 @@ class Response(_SansIOResponse): checks. Use :func:`~werkzeug.utils.send_file` instead of setting this manually. + .. versionchanged:: 2.1 + Old ``BaseResponse`` and mixin classes were removed. + .. versionchanged:: 2.0 Combine ``BaseResponse`` and mixins into a single ``Response`` - class. Using the old classes is deprecated and will be removed - in Werkzeug 2.1. + class. .. versionchanged:: 0.5 The ``direct_passthrough`` parameter was added. @@ -165,22 +139,17 @@ class Response(_SansIOResponse): #: Do not set to a plain string or bytes, that will cause sending #: the response to be very inefficient as it will iterate one byte #: at a time. 
- response: t.Union[t.Iterable[str], t.Iterable[bytes]] + response: t.Iterable[str] | t.Iterable[bytes] def __init__( self, - response: t.Optional[ - t.Union[t.Iterable[bytes], bytes, t.Iterable[str], str] - ] = None, - status: t.Optional[t.Union[int, str, HTTPStatus]] = None, - headers: t.Optional[ - t.Union[ - t.Mapping[str, t.Union[str, int, t.Iterable[t.Union[str, int]]]], - t.Iterable[t.Tuple[str, t.Union[str, int]]], - ] - ] = None, - mimetype: t.Optional[str] = None, - content_type: t.Optional[str] = None, + response: t.Iterable[bytes] | bytes | t.Iterable[str] | str | None = None, + status: int | str | HTTPStatus | None = None, + headers: t.Mapping[str, str | t.Iterable[str]] + | t.Iterable[tuple[str, str]] + | None = None, + mimetype: str | None = None, + content_type: str | None = None, direct_passthrough: bool = False, ) -> None: super().__init__( @@ -196,7 +165,7 @@ class Response(_SansIOResponse): #: :func:`~werkzeug.utils.send_file` instead of setting this #: manually. self.direct_passthrough = direct_passthrough - self._on_close: t.List[t.Callable[[], t.Any]] = [] + self._on_close: list[t.Callable[[], t.Any]] = [] # we set the response after the headers so that if a class changes # the charset attribute, the data is set in the correct charset. @@ -227,8 +196,8 @@ class Response(_SansIOResponse): @classmethod def force_type( - cls, response: "Response", environ: t.Optional["WSGIEnvironment"] = None - ) -> "Response": + cls, response: Response, environ: WSGIEnvironment | None = None + ) -> Response: """Enforce that the WSGI response is a response object of the current type. Werkzeug will use the :class:`Response` internally in many situations like the exceptions. 
If you call :meth:`get_response` on an @@ -272,8 +241,8 @@ class Response(_SansIOResponse): @classmethod def from_app( - cls, app: "WSGIApplication", environ: "WSGIEnvironment", buffered: bool = False - ) -> "Response": + cls, app: WSGIApplication, environ: WSGIEnvironment, buffered: bool = False + ) -> Response: """Create a new response object from an application output. This works best if you pass it an application that returns a generator all the time. Sometimes applications may use the `write()` callable @@ -290,15 +259,15 @@ class Response(_SansIOResponse): return cls(*run_wsgi_app(app, environ, buffered)) - @typing.overload - def get_data(self, as_text: "te.Literal[False]" = False) -> bytes: + @t.overload + def get_data(self, as_text: t.Literal[False] = False) -> bytes: ... - @typing.overload - def get_data(self, as_text: "te.Literal[True]") -> str: + @t.overload + def get_data(self, as_text: t.Literal[True]) -> str: ... - def get_data(self, as_text: bool = False) -> t.Union[bytes, str]: + def get_data(self, as_text: bool = False) -> bytes | str: """The string representation of the response body. Whenever you call this property the response iterable is encoded and flattened. This can lead to unwanted behavior if you stream big data. @@ -315,23 +284,19 @@ class Response(_SansIOResponse): rv = b"".join(self.iter_encoded()) if as_text: - return rv.decode(self.charset) + return rv.decode() return rv - def set_data(self, value: t.Union[bytes, str]) -> None: + def set_data(self, value: bytes | str) -> None: """Sets a new string as response. The value must be a string or bytes. If a string is set it's encoded to the charset of the response (utf-8 by default). .. 
versionadded:: 0.9 """ - # if a string is set, it's encoded directly so that we - # can set the content length if isinstance(value, str): - value = value.encode(self.charset) - else: - value = bytes(value) + value = value.encode() self.response = [value] if self.automatically_set_content_length: self.headers["Content-Length"] = str(len(value)) @@ -342,7 +307,7 @@ class Response(_SansIOResponse): doc="A descriptor that calls :meth:`get_data` and :meth:`set_data`.", ) - def calculate_content_length(self) -> t.Optional[int]: + def calculate_content_length(self) -> int | None: """Returns the content length if available or `None` otherwise.""" try: self._ensure_sequence() @@ -398,12 +363,10 @@ class Response(_SansIOResponse): value of this method is used as application iterator unless :attr:`direct_passthrough` was activated. """ - if __debug__: - _warn_if_string(self.response) # Encode in a separate function so that self.response is fetched # early. This allows us to wrap the response with the return # value from get_app_iter or iter_encoded. - return _iter_encoded(self.response, self.charset) + return _iter_encoded(self.response) @property def is_streamed(self) -> bool: @@ -439,11 +402,11 @@ class Response(_SansIOResponse): Can now be used in a with statement. """ if hasattr(self.response, "close"): - self.response.close() # type: ignore + self.response.close() for func in self._on_close: func() - def __enter__(self) -> "Response": + def __enter__(self) -> Response: return self def __exit__(self, exc_type, exc_value, tb): # type: ignore @@ -463,8 +426,7 @@ class Response(_SansIOResponse): Removed the ``no_etag`` parameter. .. versionchanged:: 2.0 - An ``ETag`` header is added, the ``no_etag`` parameter is - deprecated and will be removed in Werkzeug 2.1. + An ``ETag`` header is always added. .. versionchanged:: 0.6 The ``Content-Length`` header is set. 
@@ -475,7 +437,7 @@ class Response(_SansIOResponse): self.headers["Content-Length"] = str(sum(map(len, self.response))) self.add_etag() - def get_wsgi_headers(self, environ: "WSGIEnvironment") -> Headers: + def get_wsgi_headers(self, environ: WSGIEnvironment) -> Headers: """This is automatically called right before the response is started and returns headers modified for the given environment. It returns a copy of the headers from the response with some modifications applied @@ -500,9 +462,9 @@ class Response(_SansIOResponse): object. """ headers = Headers(self.headers) - location: t.Optional[str] = None - content_location: t.Optional[str] = None - content_length: t.Optional[t.Union[str, int]] = None + location: str | None = None + content_location: str | None = None + content_length: str | int | None = None status = self.status_code # iterate over the headers to find all values in one go. Because @@ -517,24 +479,19 @@ class Response(_SansIOResponse): elif ikey == "content-length": content_length = value - # make sure the location header is an absolute URL if location is not None: - old_location = location - if isinstance(location, str): - # Safe conversion is necessary here as we might redirect - # to a broken URI scheme (for instance itms-services). - location = iri_to_uri(location, safe_conversion=True) + location = _invalid_iri_to_uri(location) if self.autocorrect_location_header: + # Make the location header an absolute URL. 
current_url = get_current_url(environ, strip_querystring=True) - if isinstance(current_url, str): - current_url = iri_to_uri(current_url) - location = url_join(current_url, location) - if location != old_location: - headers["Location"] = location + current_url = iri_to_uri(current_url) + location = urljoin(current_url, location) + + headers["Location"] = location # make sure the content location is a URL - if content_location is not None and isinstance(content_location, str): + if content_location is not None: headers["Content-Location"] = iri_to_uri(content_location) if 100 <= status < 200 or status == 204: @@ -557,18 +514,12 @@ class Response(_SansIOResponse): and status not in (204, 304) and not (100 <= status < 200) ): - try: - content_length = sum(len(_to_bytes(x, "ascii")) for x in self.response) - except UnicodeError: - # Something other than bytes, can't safely figure out - # the length of the response. - pass - else: - headers["Content-Length"] = str(content_length) + content_length = sum(len(x) for x in self.iter_encoded()) + headers["Content-Length"] = str(content_length) return headers - def get_app_iter(self, environ: "WSGIEnvironment") -> t.Iterable[bytes]: + def get_app_iter(self, environ: WSGIEnvironment) -> t.Iterable[bytes]: """Returns the application iterator for the given environ. Depending on the request method and the current status code the return value might be an empty response rather than the one from the response. 
@@ -590,16 +541,14 @@ class Response(_SansIOResponse): ): iterable: t.Iterable[bytes] = () elif self.direct_passthrough: - if __debug__: - _warn_if_string(self.response) return self.response # type: ignore else: iterable = self.iter_encoded() return ClosingIterator(iterable, self.close) def get_wsgi_response( - self, environ: "WSGIEnvironment" - ) -> t.Tuple[t.Iterable[bytes], str, t.List[t.Tuple[str, str]]]: + self, environ: WSGIEnvironment + ) -> tuple[t.Iterable[bytes], str, list[tuple[str, str]]]: """Returns the final WSGI response as tuple. The first item in the tuple is the application iterator, the second the status and the third the list of headers. The response returned is created @@ -617,7 +566,7 @@ class Response(_SansIOResponse): return app_iter, self.status, headers.to_wsgi_list() def __call__( - self, environ: "WSGIEnvironment", start_response: "StartResponse" + self, environ: WSGIEnvironment, start_response: StartResponse ) -> t.Iterable[bytes]: """Process this response as WSGI application. @@ -637,7 +586,7 @@ class Response(_SansIOResponse): json_module = json @property - def json(self) -> t.Optional[t.Any]: + def json(self) -> t.Any | None: """The parsed JSON data if :attr:`mimetype` indicates JSON (:mimetype:`application/json`, see :attr:`is_json`). @@ -645,7 +594,15 @@ class Response(_SansIOResponse): """ return self.get_json() - def get_json(self, force: bool = False, silent: bool = False) -> t.Optional[t.Any]: + @t.overload + def get_json(self, force: bool = ..., silent: t.Literal[False] = ...) -> t.Any: + ... + + @t.overload + def get_json(self, force: bool = ..., silent: bool = ...) -> t.Any | None: + ... + + def get_json(self, force: bool = False, silent: bool = False) -> t.Any | None: """Parse :attr:`data` as JSON. Useful during testing. 
If the mimetype does not indicate JSON @@ -674,7 +631,7 @@ class Response(_SansIOResponse): # Stream @cached_property - def stream(self) -> "ResponseStream": + def stream(self) -> ResponseStream: """The response iterable as write-only stream.""" return ResponseStream(self) @@ -683,7 +640,7 @@ class Response(_SansIOResponse): if self.status_code == 206: self.response = _RangeWrapper(self.response, start, length) # type: ignore - def _is_range_request_processable(self, environ: "WSGIEnvironment") -> bool: + def _is_range_request_processable(self, environ: WSGIEnvironment) -> bool: """Return ``True`` if `Range` header is present and if underlying resource is considered unchanged when compared with `If-Range` header. """ @@ -700,9 +657,9 @@ class Response(_SansIOResponse): def _process_range_request( self, - environ: "WSGIEnvironment", - complete_length: t.Optional[int] = None, - accept_ranges: t.Optional[t.Union[bool, str]] = None, + environ: WSGIEnvironment, + complete_length: int | None, + accept_ranges: bool | str, ) -> bool: """Handle Range Request related headers (RFC7233). 
If `Accept-Ranges` header is valid, and Range Request is processable, we set the headers @@ -720,13 +677,16 @@ class Response(_SansIOResponse): from ..exceptions import RequestedRangeNotSatisfiable if ( - accept_ranges is None + not accept_ranges or complete_length is None or complete_length == 0 or not self._is_range_request_processable(environ) ): return False + if accept_ranges is True: + accept_ranges = "bytes" + parsed_range = parse_range_header(environ.get("HTTP_RANGE")) if parsed_range is None: @@ -739,7 +699,7 @@ class Response(_SansIOResponse): raise RequestedRangeNotSatisfiable(complete_length) content_length = range_tuple[1] - range_tuple[0] - self.headers["Content-Length"] = content_length + self.headers["Content-Length"] = str(content_length) self.headers["Accept-Ranges"] = accept_ranges self.content_range = content_range_header # type: ignore self.status_code = 206 @@ -748,10 +708,10 @@ class Response(_SansIOResponse): def make_conditional( self, - request_or_environ: t.Union["WSGIEnvironment", "Request"], - accept_ranges: t.Union[bool, str] = False, - complete_length: t.Optional[int] = None, - ) -> "Response": + request_or_environ: WSGIEnvironment | Request, + accept_ranges: bool | str = False, + complete_length: int | None = None, + ) -> Response: """Make the response conditional to the request. This method works best if an etag was defined for the response already. The `add_etag` method can be used to do that. If called without etag just the date @@ -777,8 +737,7 @@ class Response(_SansIOResponse): :param accept_ranges: This parameter dictates the value of `Accept-Ranges` header. If ``False`` (default), the header is not set. If ``True``, it will be set - to ``"bytes"``. If ``None``, it will be set to - ``"none"``. If it's a string, it will use this + to ``"bytes"``. If it's a string, it will use this value. :param complete_length: Will be used only in valid Range Requests. 
It will set `Content-Range` complete length @@ -800,7 +759,6 @@ class Response(_SansIOResponse): # wsgiref. if "date" not in self.headers: self.headers["Date"] = http_date() - accept_ranges = _clean_accept_ranges(accept_ranges) is206 = self._process_range_request(environ, complete_length, accept_ranges) if not is206 and not is_resource_modified( environ, @@ -818,7 +776,7 @@ class Response(_SansIOResponse): ): length = self.calculate_content_length() if length is not None: - self.headers["Content-Length"] = length + self.headers["Content-Length"] = str(length) return self def add_etag(self, overwrite: bool = False, weak: bool = False) -> None: @@ -874,4 +832,4 @@ class ResponseStream: @property def encoding(self) -> str: - return self.response.charset + return "utf-8" diff --git a/src/werkzeug/wsgi.py b/src/werkzeug/wsgi.py index 24ece0b..01d40af 100644 --- a/src/werkzeug/wsgi.py +++ b/src/werkzeug/wsgi.py @@ -1,28 +1,21 @@ +from __future__ import annotations + import io -import re import typing as t -import warnings from functools import partial from functools import update_wrapper -from itertools import chain -from ._internal import _make_encode_wrapper -from ._internal import _to_bytes -from ._internal import _to_str +from .exceptions import ClientDisconnected +from .exceptions import RequestEntityTooLarge from .sansio import utils as _sansio_utils from .sansio.utils import host_is_trusted # noqa: F401 # Imported as part of API -from .urls import _URLTuple -from .urls import uri_to_iri -from .urls import url_join -from .urls import url_parse -from .urls import url_quote if t.TYPE_CHECKING: from _typeshed.wsgi import WSGIApplication from _typeshed.wsgi import WSGIEnvironment -def responder(f: t.Callable[..., "WSGIApplication"]) -> "WSGIApplication": +def responder(f: t.Callable[..., WSGIApplication]) -> WSGIApplication: """Marks a function as responder. Decorate a function with it and it will automatically call the return value as WSGI application. 
@@ -36,11 +29,11 @@ def responder(f: t.Callable[..., "WSGIApplication"]) -> "WSGIApplication": def get_current_url( - environ: "WSGIEnvironment", + environ: WSGIEnvironment, root_only: bool = False, strip_querystring: bool = False, host_only: bool = False, - trusted_hosts: t.Optional[t.Iterable[str]] = None, + trusted_hosts: t.Iterable[str] | None = None, ) -> str: """Recreate the URL for a request from the parts in a WSGI environment. @@ -74,15 +67,15 @@ def get_current_url( def _get_server( - environ: "WSGIEnvironment", -) -> t.Optional[t.Tuple[str, t.Optional[int]]]: + environ: WSGIEnvironment, +) -> tuple[str, int | None] | None: name = environ.get("SERVER_NAME") if name is None: return None try: - port: t.Optional[int] = int(environ.get("SERVER_PORT", None)) + port: int | None = int(environ.get("SERVER_PORT", None)) except (TypeError, ValueError): # unix socket port = None @@ -91,7 +84,7 @@ def _get_server( def get_host( - environ: "WSGIEnvironment", trusted_hosts: t.Optional[t.Iterable[str]] = None + environ: WSGIEnvironment, trusted_hosts: t.Iterable[str] | None = None ) -> str: """Return the host for the given WSGI environment. @@ -118,337 +111,101 @@ def get_host( ) -def get_content_length(environ: "WSGIEnvironment") -> t.Optional[int]: - """Returns the content length from the WSGI environment as - integer. If it's not available or chunked transfer encoding is used, - ``None`` is returned. +def get_content_length(environ: WSGIEnvironment) -> int | None: + """Return the ``Content-Length`` header value as an int. If the header is not given + or the ``Transfer-Encoding`` header is ``chunked``, ``None`` is returned to indicate + a streaming request. If the value is not an integer, or negative, 0 is returned. + + :param environ: The WSGI environ to get the content length from. .. versionadded:: 0.9 - - :param environ: the WSGI environ to fetch the content length from. 
""" return _sansio_utils.get_content_length( http_content_length=environ.get("CONTENT_LENGTH"), - http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING", ""), + http_transfer_encoding=environ.get("HTTP_TRANSFER_ENCODING"), ) def get_input_stream( - environ: "WSGIEnvironment", safe_fallback: bool = True + environ: WSGIEnvironment, + safe_fallback: bool = True, + max_content_length: int | None = None, ) -> t.IO[bytes]: - """Returns the input stream from the WSGI environment and wraps it - in the most sensible way possible. The stream returned is not the - raw WSGI stream in most cases but one that is safe to read from - without taking into account the content length. + """Return the WSGI input stream, wrapped so that it may be read safely without going + past the ``Content-Length`` header value or ``max_content_length``. - If content length is not set, the stream will be empty for safety reasons. - If the WSGI server supports chunked or infinite streams, it should set - the ``wsgi.input_terminated`` value in the WSGI environ to indicate that. + If ``Content-Length`` exceeds ``max_content_length``, a + :exc:`RequestEntityTooLarge`` ``413 Content Too Large`` error is raised. + + If the WSGI server sets ``environ["wsgi.input_terminated"]``, it indicates that the + server handles terminating the stream, so it is safe to read directly. For example, + a server that knows how to handle chunked requests safely would set this. + + If ``max_content_length`` is set, it can be enforced on streams if + ``wsgi.input_terminated`` is set. Otherwise, an empty stream is returned unless the + user explicitly disables this safe fallback. + + If the limit is reached before the underlying stream is exhausted (such as a file + that is too large, or an infinite stream), the remaining contents of the stream + cannot be read safely. Depending on how the server handles this, clients may show a + "connection reset" failure instead of seeing the 413 response. 
+ + :param environ: The WSGI environ containing the stream. + :param safe_fallback: Return an empty stream when ``Content-Length`` is not set. + Disabling this allows infinite streams, which can be a denial-of-service risk. + :param max_content_length: The maximum length that content-length or streaming + requests may not exceed. + + .. versionchanged:: 2.3.2 + ``max_content_length`` is only applied to streaming requests if the server sets + ``wsgi.input_terminated``. + + .. versionchanged:: 2.3 + Check ``max_content_length`` and raise an error if it is exceeded. .. versionadded:: 0.9 - - :param environ: the WSGI environ to fetch the stream from. - :param safe_fallback: use an empty stream as a safe fallback when the - content length is not set. Disabling this allows infinite streams, - which can be a denial-of-service risk. """ stream = t.cast(t.IO[bytes], environ["wsgi.input"]) content_length = get_content_length(environ) - # A wsgi extension that tells us if the input is terminated. In - # that case we return the stream unchanged as we know we can safely - # read it until the end. - if environ.get("wsgi.input_terminated"): + if content_length is not None and max_content_length is not None: + if content_length > max_content_length: + raise RequestEntityTooLarge() + + # A WSGI server can set this to indicate that it terminates the input stream. In + # that case the stream is safe without wrapping, or can enforce a max length. + if "wsgi.input_terminated" in environ: + if max_content_length is not None: + # If this is moved above, it can cause the stream to hang if a read attempt + # is made when the client sends no data. For example, the development server + # does not handle buffering except for chunked encoding. 
+ return t.cast( + t.IO[bytes], LimitedStream(stream, max_content_length, is_max=True) + ) + return stream - # If the request doesn't specify a content length, returning the stream is - # potentially dangerous because it could be infinite, malicious or not. If - # safe_fallback is true, return an empty stream instead for safety. + # No limit given, return an empty stream unless the user explicitly allows the + # potentially infinite stream. An infinite stream is dangerous if it's not expected, + # as it can tie up a worker indefinitely. if content_length is None: return io.BytesIO() if safe_fallback else stream - # Otherwise limit the stream to the content length return t.cast(t.IO[bytes], LimitedStream(stream, content_length)) -def get_query_string(environ: "WSGIEnvironment") -> str: - """Returns the ``QUERY_STRING`` from the WSGI environment. This also - takes care of the WSGI decoding dance. The string returned will be - restricted to ASCII characters. - - :param environ: WSGI environment to get the query string from. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.9 - """ - warnings.warn( - "'get_query_string' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - qs = environ.get("QUERY_STRING", "").encode("latin1") - # QUERY_STRING really should be ascii safe but some browsers - # will send us some unicode stuff (I am looking at you IE). - # In that case we want to urllib quote it badly. - return url_quote(qs, safe=":&%=+$!*'(),") - - -def get_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> str: - """Return the ``PATH_INFO`` from the WSGI environment and decode it - unless ``charset`` is ``None``. +def get_path_info(environ: WSGIEnvironment) -> str: + """Return ``PATH_INFO`` from the WSGI environment. :param environ: WSGI environment to get the path from. 
- :param charset: The charset for the path info, or ``None`` if no - decoding should be performed. - :param errors: The decoding error handling. + + .. versionchanged:: 3.0 + The ``charset`` and ``errors`` parameters were removed. .. versionadded:: 0.9 """ - path = environ.get("PATH_INFO", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore - - -def get_script_name( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> str: - """Return the ``SCRIPT_NAME`` from the WSGI environment and decode - it unless `charset` is set to ``None``. - - :param environ: WSGI environment to get the path from. - :param charset: The charset for the path, or ``None`` if no decoding - should be performed. - :param errors: The decoding error handling. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.9 - """ - warnings.warn( - "'get_script_name' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - path = environ.get("SCRIPT_NAME", "").encode("latin1") - return _to_str(path, charset, errors, allow_none_charset=True) # type: ignore - - -def pop_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> t.Optional[str]: - """Removes and returns the next segment of `PATH_INFO`, pushing it onto - `SCRIPT_NAME`. Returns `None` if there is nothing left on `PATH_INFO`. - - If the `charset` is set to `None` bytes are returned. - - If there are empty segments (``'/foo//bar``) these are ignored but - properly pushed to the `SCRIPT_NAME`: - - >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'} - >>> pop_path_info(env) - 'a' - >>> env['SCRIPT_NAME'] - '/foo/a' - >>> pop_path_info(env) - 'b' - >>> env['SCRIPT_NAME'] - '/foo/a/b' - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.5 - - .. 
versionchanged:: 0.9 - The path is now decoded and a charset and encoding - parameter can be provided. - - :param environ: the WSGI environment that is modified. - :param charset: The ``encoding`` parameter passed to - :func:`bytes.decode`. - :param errors: The ``errors`` paramater passed to - :func:`bytes.decode`. - """ - warnings.warn( - "'pop_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - path = environ.get("PATH_INFO") - if not path: - return None - - script_name = environ.get("SCRIPT_NAME", "") - - # shift multiple leading slashes over - old_path = path - path = path.lstrip("/") - if path != old_path: - script_name += "/" * (len(old_path) - len(path)) - - if "/" not in path: - environ["PATH_INFO"] = "" - environ["SCRIPT_NAME"] = script_name + path - rv = path.encode("latin1") - else: - segment, path = path.split("/", 1) - environ["PATH_INFO"] = f"/{path}" - environ["SCRIPT_NAME"] = script_name + segment - rv = segment.encode("latin1") - - return _to_str(rv, charset, errors, allow_none_charset=True) # type: ignore - - -def peek_path_info( - environ: "WSGIEnvironment", charset: str = "utf-8", errors: str = "replace" -) -> t.Optional[str]: - """Returns the next segment on the `PATH_INFO` or `None` if there - is none. Works like :func:`pop_path_info` without modifying the - environment: - - >>> env = {'SCRIPT_NAME': '/foo', 'PATH_INFO': '/a/b'} - >>> peek_path_info(env) - 'a' - >>> peek_path_info(env) - 'a' - - If the `charset` is set to `None` bytes are returned. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionadded:: 0.5 - - .. versionchanged:: 0.9 - The path is now decoded and a charset and encoding - parameter can be provided. - - :param environ: the WSGI environment that is checked. 
- """ - warnings.warn( - "'peek_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - segments = environ.get("PATH_INFO", "").lstrip("/").split("/", 1) - if segments: - return _to_str( # type: ignore - segments[0].encode("latin1"), charset, errors, allow_none_charset=True - ) - return None - - -def extract_path_info( - environ_or_baseurl: t.Union[str, "WSGIEnvironment"], - path_or_url: t.Union[str, _URLTuple], - charset: str = "utf-8", - errors: str = "werkzeug.url_quote", - collapse_http_schemes: bool = True, -) -> t.Optional[str]: - """Extracts the path info from the given URL (or WSGI environment) and - path. The path info returned is a string. The URLs might also be IRIs. - - If the path info could not be determined, `None` is returned. - - Some examples: - - >>> extract_path_info('http://example.com/app', '/app/hello') - '/hello' - >>> extract_path_info('http://example.com/app', - ... 'https://example.com/app/hello') - '/hello' - >>> extract_path_info('http://example.com/app', - ... 'https://example.com/app/hello', - ... collapse_http_schemes=False) is None - True - - Instead of providing a base URL you can also pass a WSGI environment. - - :param environ_or_baseurl: a WSGI environment dict, a base URL or - base IRI. This is the root of the - application. - :param path_or_url: an absolute path from the server root, a - relative path (in which case it's the path info) - or a full URL. - :param charset: the charset for byte data in URLs - :param errors: the error handling on decode - :param collapse_http_schemes: if set to `False` the algorithm does - not assume that http and https on the - same server point to the same - resource. - - .. deprecated:: 2.2 - Will be removed in Werkzeug 2.3. - - .. versionchanged:: 0.15 - The ``errors`` parameter defaults to leaving invalid bytes - quoted instead of replacing them. - - .. 
versionadded:: 0.6 - - """ - warnings.warn( - "'extract_path_info' is deprecated and will be removed in Werkzeug 2.3.", - DeprecationWarning, - stacklevel=2, - ) - - def _normalize_netloc(scheme: str, netloc: str) -> str: - parts = netloc.split("@", 1)[-1].split(":", 1) - port: t.Optional[str] - - if len(parts) == 2: - netloc, port = parts - if (scheme == "http" and port == "80") or ( - scheme == "https" and port == "443" - ): - port = None - else: - netloc = parts[0] - port = None - - if port is not None: - netloc += f":{port}" - - return netloc - - # make sure whatever we are working on is a IRI and parse it - path = uri_to_iri(path_or_url, charset, errors) - if isinstance(environ_or_baseurl, dict): - environ_or_baseurl = get_current_url(environ_or_baseurl, root_only=True) - base_iri = uri_to_iri(environ_or_baseurl, charset, errors) - base_scheme, base_netloc, base_path = url_parse(base_iri)[:3] - cur_scheme, cur_netloc, cur_path = url_parse(url_join(base_iri, path))[:3] - - # normalize the network location - base_netloc = _normalize_netloc(base_scheme, base_netloc) - cur_netloc = _normalize_netloc(cur_scheme, cur_netloc) - - # is that IRI even on a known HTTP scheme? - if collapse_http_schemes: - for scheme in base_scheme, cur_scheme: - if scheme not in ("http", "https"): - return None - else: - if not (base_scheme in ("http", "https") and base_scheme == cur_scheme): - return None - - # are the netlocs compatible? - if base_netloc != cur_netloc: - return None - - # are we below the application path? 
- base_path = base_path.rstrip("/") - if not cur_path.startswith(base_path): - return None - - return f"/{cur_path[len(base_path) :].lstrip('/')}" + path: bytes = environ.get("PATH_INFO", "").encode("latin1") + return path.decode(errors="replace") class ClosingIterator: @@ -476,9 +233,8 @@ class ClosingIterator: def __init__( self, iterable: t.Iterable[bytes], - callbacks: t.Optional[ - t.Union[t.Callable[[], None], t.Iterable[t.Callable[[], None]]] - ] = None, + callbacks: None + | (t.Callable[[], None] | t.Iterable[t.Callable[[], None]]) = None, ) -> None: iterator = iter(iterable) self._next = t.cast(t.Callable[[], bytes], partial(next, iterator)) @@ -493,7 +249,7 @@ class ClosingIterator: callbacks.insert(0, iterable_close) self._callbacks = callbacks - def __iter__(self) -> "ClosingIterator": + def __iter__(self) -> ClosingIterator: return self def __next__(self) -> bytes: @@ -505,7 +261,7 @@ class ClosingIterator: def wrap_file( - environ: "WSGIEnvironment", file: t.IO[bytes], buffer_size: int = 8192 + environ: WSGIEnvironment, file: t.IO[bytes], buffer_size: int = 8192 ) -> t.Iterable[bytes]: """Wraps a file. This uses the WSGI server's file wrapper if available or otherwise the generic :class:`FileWrapper`. 
@@ -564,12 +320,12 @@ class FileWrapper: if hasattr(self.file, "seek"): self.file.seek(*args) - def tell(self) -> t.Optional[int]: + def tell(self) -> int | None: if hasattr(self.file, "tell"): return self.file.tell() return None - def __iter__(self) -> "FileWrapper": + def __iter__(self) -> FileWrapper: return self def __next__(self) -> bytes: @@ -598,9 +354,9 @@ class _RangeWrapper: def __init__( self, - iterable: t.Union[t.Iterable[bytes], t.IO[bytes]], + iterable: t.Iterable[bytes] | t.IO[bytes], start_byte: int = 0, - byte_range: t.Optional[int] = None, + byte_range: int | None = None, ): self.iterable = iter(iterable) self.byte_range = byte_range @@ -611,12 +367,10 @@ class _RangeWrapper: self.end_byte = start_byte + byte_range self.read_length = 0 - self.seekable = ( - hasattr(iterable, "seekable") and iterable.seekable() # type: ignore - ) + self.seekable = hasattr(iterable, "seekable") and iterable.seekable() self.end_reached = False - def __iter__(self) -> "_RangeWrapper": + def __iter__(self) -> _RangeWrapper: return self def _next_chunk(self) -> bytes: @@ -628,7 +382,7 @@ class _RangeWrapper: self.end_reached = True raise - def _first_iteration(self) -> t.Tuple[t.Optional[bytes], int]: + def _first_iteration(self) -> tuple[bytes | None, int]: chunk = None if self.seekable: self.iterable.seek(self.start_byte) # type: ignore @@ -665,356 +419,177 @@ class _RangeWrapper: def close(self) -> None: if hasattr(self.iterable, "close"): - self.iterable.close() # type: ignore + self.iterable.close() -def _make_chunk_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - limit: t.Optional[int], - buffer_size: int, -) -> t.Iterator[bytes]: - """Helper for the line and chunk iter functions.""" - if isinstance(stream, (bytes, bytearray, str)): - raise TypeError( - "Passed a string or byte object instead of true iterator or stream." 
- ) - if not hasattr(stream, "read"): - for item in stream: - if item: - yield item - return - stream = t.cast(t.IO[bytes], stream) - if not isinstance(stream, LimitedStream) and limit is not None: - stream = t.cast(t.IO[bytes], LimitedStream(stream, limit)) - _read = stream.read - while True: - item = _read(buffer_size) - if not item: - break - yield item +class LimitedStream(io.RawIOBase): + """Wrap a stream so that it doesn't read more than a given limit. This is used to + limit ``wsgi.input`` to the ``Content-Length`` header value or + :attr:`.Request.max_content_length`. + When attempting to read after the limit has been reached, :meth:`on_exhausted` is + called. When the limit is a maximum, this raises :exc:`.RequestEntityTooLarge`. -def make_line_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - limit: t.Optional[int] = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Safely iterates line-based over an input stream. If the input stream - is not a :class:`LimitedStream` the `limit` parameter is mandatory. + If reading from the stream returns zero bytes or raises an error, + :meth:`on_disconnect` is called, which raises :exc:`.ClientDisconnected`. When the + limit is a maximum and zero bytes were read, no error is raised, since it may be the + end of the stream. - This uses the stream's :meth:`~file.read` method internally as opposite - to the :meth:`~file.readline` method that is unsafe and can only be used - in violation of the WSGI specification. The same problem applies to the - `__iter__` function of the input stream which calls :meth:`~file.readline` - without arguments. + If the limit is reached before the underlying stream is exhausted (such as a file + that is too large, or an infinite stream), the remaining contents of the stream + cannot be read safely. Depending on how the server handles this, clients may show a + "connection reset" failure instead of seeing the 413 response. 
- If you need line-by-line processing it's strongly recommended to iterate - over the input stream using this helper function. + :param stream: The stream to read from. Must be a readable binary IO object. + :param limit: The limit in bytes to not read past. Should be either the + ``Content-Length`` header value or ``request.max_content_length``. + :param is_max: Whether the given ``limit`` is ``request.max_content_length`` instead + of the ``Content-Length`` header value. This changes how exhausted and + disconnect events are handled. - .. versionchanged:: 0.8 - This function now ensures that the limit was reached. + .. versionchanged:: 2.3 + Handle ``max_content_length`` differently than ``Content-Length``. - .. versionadded:: 0.9 - added support for iterators as input stream. - - .. versionadded:: 0.11.10 - added support for the `cap_at_buffer` parameter. - - :param stream: the stream or iterate to iterate over. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is a :class:`LimitedStream`. - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. 
- """ - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, "") - if not first_item: - return - - s = _make_encode_wrapper(first_item) - empty = t.cast(bytes, s("")) - cr = t.cast(bytes, s("\r")) - lf = t.cast(bytes, s("\n")) - crlf = t.cast(bytes, s("\r\n")) - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - - def _iter_basic_lines() -> t.Iterator[bytes]: - _join = empty.join - buffer: t.List[bytes] = [] - while True: - new_data = next(_iter, "") - if not new_data: - break - new_buf: t.List[bytes] = [] - buf_size = 0 - for item in t.cast( - t.Iterator[bytes], chain(buffer, new_data.splitlines(True)) - ): - new_buf.append(item) - buf_size += len(item) - if item and item[-1:] in crlf: - yield _join(new_buf) - new_buf = [] - elif cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buffer = new_buf - if buffer: - yield _join(buffer) - - # This hackery is necessary to merge 'foo\r' and '\n' into one item - # of 'foo\r\n' if we were unlucky and we hit a chunk boundary. - previous = empty - for item in _iter_basic_lines(): - if item == lf and previous[-1:] == cr: - previous += item - item = empty - if previous: - yield previous - previous = item - if previous: - yield previous - - -def make_chunk_iter( - stream: t.Union[t.Iterable[bytes], t.IO[bytes]], - separator: bytes, - limit: t.Optional[int] = None, - buffer_size: int = 10 * 1024, - cap_at_buffer: bool = False, -) -> t.Iterator[bytes]: - """Works like :func:`make_line_iter` but accepts a separator - which divides chunks. If you want newline based processing - you should use :func:`make_line_iter` instead as it - supports arbitrary newline markers. - - .. versionadded:: 0.8 - - .. versionadded:: 0.9 - added support for iterators as input stream. - - .. versionadded:: 0.11.10 - added support for the `cap_at_buffer` parameter. 
- - :param stream: the stream or iterate to iterate over. - :param separator: the separator that divides chunks. - :param limit: the limit in bytes for the stream. (Usually - content length. Not necessary if the `stream` - is otherwise already limited). - :param buffer_size: The optional buffer size. - :param cap_at_buffer: if this is set chunks are split if they are longer - than the buffer size. Internally this is implemented - that the buffer size might be exhausted by a factor - of two however. - """ - _iter = _make_chunk_iter(stream, limit, buffer_size) - - first_item = next(_iter, b"") - if not first_item: - return - - _iter = t.cast(t.Iterator[bytes], chain((first_item,), _iter)) - if isinstance(first_item, str): - separator = _to_str(separator) - _split = re.compile(f"({re.escape(separator)})").split - _join = "".join - else: - separator = _to_bytes(separator) - _split = re.compile(b"(" + re.escape(separator) + b")").split - _join = b"".join - - buffer: t.List[bytes] = [] - while True: - new_data = next(_iter, b"") - if not new_data: - break - chunks = _split(new_data) - new_buf: t.List[bytes] = [] - buf_size = 0 - for item in chain(buffer, chunks): - if item == separator: - yield _join(new_buf) - new_buf = [] - buf_size = 0 - else: - buf_size += len(item) - new_buf.append(item) - - if cap_at_buffer and buf_size >= buffer_size: - rv = _join(new_buf) - while len(rv) >= buffer_size: - yield rv[:buffer_size] - rv = rv[buffer_size:] - new_buf = [rv] - buf_size = len(rv) - - buffer = new_buf - if buffer: - yield _join(buffer) - - -class LimitedStream(io.IOBase): - """Wraps a stream so that it doesn't read more than n bytes. If the - stream is exhausted and the caller tries to get more bytes from it - :func:`on_exhausted` is called which by default returns an empty - string. The return value of that function is forwarded - to the reader function. So if it returns an empty string - :meth:`read` will return an empty string as well. 
- - The limit however must never be higher than what the stream can - output. Otherwise :meth:`readlines` will try to read past the - limit. - - .. admonition:: Note on WSGI compliance - - calls to :meth:`readline` and :meth:`readlines` are not - WSGI compliant because it passes a size argument to the - readline methods. Unfortunately the WSGI PEP is not safely - implementable without a size argument to :meth:`readline` - because there is no EOF marker in the stream. As a result - of that the use of :meth:`readline` is discouraged. - - For the same reason iterating over the :class:`LimitedStream` - is not portable. It internally calls :meth:`readline`. - - We strongly suggest using :meth:`read` only or using the - :func:`make_line_iter` which safely iterates line-based - over a WSGI input stream. - - :param stream: the stream to wrap. - :param limit: the limit for the stream, must not be longer than - what the string can provide if the stream does not - end with `EOF` (like `wsgi.input`) + .. versionchanged:: 2.3 + Implements ``io.RawIOBase`` rather than ``io.IOBase``. """ - def __init__(self, stream: t.IO[bytes], limit: int) -> None: - self._read = stream.read - self._readline = stream.readline + def __init__(self, stream: t.IO[bytes], limit: int, is_max: bool = False) -> None: + self._stream = stream self._pos = 0 self.limit = limit - - def __iter__(self) -> "LimitedStream": - return self + self._limit_is_max = is_max @property def is_exhausted(self) -> bool: - """If the stream is exhausted this attribute is `True`.""" + """Whether the current stream position has reached the limit.""" return self._pos >= self.limit - def on_exhausted(self) -> bytes: - """This is called when the stream tries to read past the limit. - The return value of this function is returned from the reading - function. + def on_exhausted(self) -> None: + """Called when attempting to read after the limit has been reached. 
+ + The default behavior is to do nothing, unless the limit is a maximum, in which + case it raises :exc:`.RequestEntityTooLarge`. + + .. versionchanged:: 2.3 + Raises ``RequestEntityTooLarge`` if the limit is a maximum. + + .. versionchanged:: 2.3 + Any return value is ignored. """ - # Read null bytes from the stream so that we get the - # correct end of stream marker. - return self._read(0) + if self._limit_is_max: + raise RequestEntityTooLarge() - def on_disconnect(self) -> bytes: - """What should happen if a disconnect is detected? The return - value of this function is returned from read functions in case - the client went away. By default a - :exc:`~werkzeug.exceptions.ClientDisconnected` exception is raised. + def on_disconnect(self, error: Exception | None = None) -> None: + """Called when an attempted read receives zero bytes before the limit was + reached. This indicates that the client disconnected before sending the full + request body. + + The default behavior is to raise :exc:`.ClientDisconnected`, unless the limit is + a maximum and no error was raised. + + .. versionchanged:: 2.3 + Added the ``error`` parameter. Do nothing if the limit is a maximum and no + error was raised. + + .. versionchanged:: 2.3 + Any return value is ignored. """ - from .exceptions import ClientDisconnected + if not self._limit_is_max or error is not None: + raise ClientDisconnected() - raise ClientDisconnected() + # If the limit is a maximum, then we may have read zero bytes because the + # streaming body is complete. There's no way to distinguish that from the + # client disconnecting early. - def exhaust(self, chunk_size: int = 1024 * 64) -> None: - """Exhaust the stream. This consumes all the data left until the - limit is reached. + def exhaust(self) -> bytes: + """Exhaust the stream by reading until the limit is reached or the client + disconnects, returning the remaining data. - :param chunk_size: the size for a chunk. 
It will read the chunk - until the stream is exhausted and throw away - the results. + .. versionchanged:: 2.3 + Return the remaining data. + + .. versionchanged:: 2.2.3 + Handle case where wrapped stream returns fewer bytes than requested. """ - to_read = self.limit - self._pos - chunk = chunk_size - while to_read > 0: - chunk = min(to_read, chunk) - self.read(chunk) - to_read -= chunk + if not self.is_exhausted: + return self.readall() - def read(self, size: t.Optional[int] = None) -> bytes: - """Read `size` bytes or if size is not provided everything is read. + return b"" - :param size: the number of bytes read. - """ - if self._pos >= self.limit: - return self.on_exhausted() - if size is None or size == -1: # -1 is for consistence with file - size = self.limit - to_read = min(self.limit - self._pos, size) - try: - read = self._read(to_read) - except (OSError, ValueError): - return self.on_disconnect() - if to_read and len(read) != to_read: - return self.on_disconnect() - self._pos += len(read) - return read + def readinto(self, b: bytearray) -> int | None: # type: ignore[override] + size = len(b) + remaining = self.limit - self._pos - def readline(self, size: t.Optional[int] = None) -> bytes: - """Reads one line from the stream.""" - if self._pos >= self.limit: - return self.on_exhausted() - if size is None: - size = self.limit - self._pos + if remaining <= 0: + self.on_exhausted() + return 0 + + if hasattr(self._stream, "readinto"): + # Use stream.readinto if it's available. + if size <= remaining: + # The size fits in the remaining limit, use the buffer directly. + try: + out_size: int | None = self._stream.readinto(b) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 + else: + # Use a temp buffer with the remaining limit as the size. 
+ temp_b = bytearray(remaining) + + try: + out_size = self._stream.readinto(temp_b) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 + + if out_size: + b[:out_size] = temp_b else: - size = min(size, self.limit - self._pos) - try: - line = self._readline(size) - except (ValueError, OSError): - return self.on_disconnect() - if size and not line: - return self.on_disconnect() - self._pos += len(line) - return line + # WSGI requires that stream.read is available. + try: + data = self._stream.read(min(size, remaining)) + except (OSError, ValueError) as e: + self.on_disconnect(error=e) + return 0 - def readlines(self, size: t.Optional[int] = None) -> t.List[bytes]: - """Reads a file into a list of strings. It calls :meth:`readline` - until the file is read to the end. It does support the optional - `size` argument if the underlying stream supports it for - `readline`. - """ - last_pos = self._pos - result = [] - if size is not None: - end = min(self.limit, last_pos + size) - else: - end = self.limit - while True: - if size is not None: - size -= last_pos - self._pos - if self._pos >= end: + out_size = len(data) + b[:out_size] = data + + if not out_size: + # Read zero bytes from the stream. + self.on_disconnect() + return 0 + + self._pos += out_size + return out_size + + def readall(self) -> bytes: + if self.is_exhausted: + self.on_exhausted() + return b"" + + out = bytearray() + + # The parent implementation uses "while True", which results in an extra read. + while not self.is_exhausted: + data = self.read(1024 * 64) + + # Stream may return empty before a max limit is reached. + if not data: break - result.append(self.readline(size)) - if size is not None: - last_pos = self._pos - return result + + out.extend(data) + + return bytes(out) def tell(self) -> int: - """Returns the position of the stream. + """Return the current stream position. .. 
versionadded:: 0.9 """ return self._pos - def __next__(self) -> bytes: - line = self.readline() - if not line: - raise StopIteration() - return line - def readable(self) -> bool: return True diff --git a/tests/conftest.py b/tests/conftest.py index 4b6c6cc..b73202c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,7 +41,9 @@ class DevServerClient: self.log = None def tail_log(self, path): - self.log = open(path) + # surrogateescape allows for handling of file streams + # containing junk binary values as normal text streams + self.log = open(path, errors="surrogateescape") self.log.read() def connect(self, **kwargs): @@ -101,15 +103,9 @@ def dev_server(xprocess, request, tmp_path): class Starter(ProcessStarter): args = [sys.executable, run_path, name, json.dumps(kwargs)] # Extend the existing env, otherwise Windows and CI fails. - # Modules will be imported from tmp_path for the reloader - # but any existing PYTHONPATH is preserved. + # Modules will be imported from tmp_path for the reloader. # Unbuffered output so the logs update immediately. 
- original_python_path = os.getenv("PYTHONPATH") - if original_python_path: - new_python_path = os.pathsep.join((original_python_path, str(tmp_path))) - else: - new_python_path = str(tmp_path) - env = {**os.environ, "PYTHONPATH": new_python_path, "PYTHONUNBUFFERED": "1"} + env = {**os.environ, "PYTHONPATH": str(tmp_path), "PYTHONUNBUFFERED": "1"} @cached_property def pattern(self): diff --git a/tests/live_apps/data_app.py b/tests/live_apps/data_app.py index a7158c7..561390a 100644 --- a/tests/live_apps/data_app.py +++ b/tests/live_apps/data_app.py @@ -5,12 +5,12 @@ from werkzeug.wrappers import Response @Request.application -def app(request): +def app(request: Request) -> Response: return Response( json.dumps( { "environ": request.environ, - "form": request.form, + "form": request.form.to_dict(), "files": {k: v.read().decode("utf8") for k, v in request.files.items()}, }, default=lambda x: str(x), diff --git a/tests/middleware/test_dispatcher.py b/tests/middleware/test_dispatcher.py index 5a25a6c..b7b9a77 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -1,4 +1,3 @@ -from werkzeug._internal import _to_bytes from werkzeug.middleware.dispatcher import DispatcherMiddleware from werkzeug.test import create_environ from werkzeug.test import run_wsgi_app @@ -11,7 +10,7 @@ def test_dispatcher(): def dummy_application(environ, start_response): start_response("200 OK", [("Content-Type", "text/plain")]) - yield _to_bytes(environ["SCRIPT_NAME"]) + yield environ["SCRIPT_NAME"].encode() app = DispatcherMiddleware( null_application, @@ -27,7 +26,7 @@ def test_dispatcher(): environ = create_environ(p) app_iter, status, headers = run_wsgi_app(app, environ) assert status == "200 OK" - assert b"".join(app_iter).strip() == _to_bytes(name) + assert b"".join(app_iter).strip() == name.encode() app_iter, status, headers = run_wsgi_app(app, create_environ("/missing")) assert status == "404 NOT FOUND" diff --git 
a/tests/middleware/test_profiler.py b/tests/middleware/test_profiler.py new file mode 100644 index 0000000..585aeb5 --- /dev/null +++ b/tests/middleware/test_profiler.py @@ -0,0 +1,50 @@ +import datetime +import os +from unittest.mock import ANY +from unittest.mock import MagicMock +from unittest.mock import patch + +from werkzeug.middleware.profiler import Profile +from werkzeug.middleware.profiler import ProfilerMiddleware +from werkzeug.test import Client + + +def dummy_application(environ, start_response): + start_response("200 OK", [("Content-Type", "text/plain")]) + return [b"Foo"] + + +def test_filename_format_function(): + # This should be called once with the generated file name + mock_capture_name = MagicMock() + + def filename_format(env): + now = datetime.datetime.fromtimestamp(env["werkzeug.profiler"]["time"]) + timestamp = now.strftime("%Y-%m-%d:%H:%M:%S") + path = ( + "_".join(token for token in env["PATH_INFO"].split("/") if token) or "ROOT" + ) + elapsed = env["werkzeug.profiler"]["elapsed"] + name = f"{timestamp}.{env['REQUEST_METHOD']}.{path}.{elapsed:.0f}ms.prof" + mock_capture_name(name=name) + return name + + client = Client( + ProfilerMiddleware( + dummy_application, + stream=None, + profile_dir="profiles", + filename_format=filename_format, + ) + ) + + # Replace the Profile class with a function that simulates an __init__() + # call and returns our mock instance. 
+ mock_profile = MagicMock(wraps=Profile()) + mock_profile.dump_stats = MagicMock() + with patch("werkzeug.middleware.profiler.Profile", lambda: mock_profile): + client.get("/foo/bar") + + mock_capture_name.assert_called_once_with(name=ANY) + name = mock_capture_name.mock_calls[0].kwargs["name"] + mock_profile.dump_stats.assert_called_once_with(os.path.join("profiles", name)) diff --git a/tests/sansio/test_multipart.py b/tests/sansio/test_multipart.py index f9c48b4..35109d4 100644 --- a/tests/sansio/test_multipart.py +++ b/tests/sansio/test_multipart.py @@ -1,3 +1,5 @@ +import pytest + from werkzeug.datastructures import Headers from werkzeug.sansio.multipart import Data from werkzeug.sansio.multipart import Epilogue @@ -30,7 +32,7 @@ asdasd decoder.receive_data(data) decoder.receive_data(None) events = [decoder.next_event()] - while not isinstance(events[-1], Epilogue) and len(events) < 6: + while not isinstance(events[-1], Epilogue): events.append(decoder.next_event()) assert events == [ Preamble(data=b""), @@ -56,6 +58,57 @@ asdasd assert data == result +@pytest.mark.parametrize( + "data_start", + [ + b"A", + b"\n", + b"\r", + b"\r\n", + b"\n\r", + b"A\n", + b"A\r", + b"A\r\n", + b"A\n\r", + ], +) +@pytest.mark.parametrize("data_end", [b"", b"\r\n--foo"]) +def test_decoder_data_start_with_different_newline_positions( + data_start: bytes, data_end: bytes +) -> None: + boundary = b"foo" + data = ( + b"\r\n--foo\r\n" + b'Content-Disposition: form-data; name="test"; filename="testfile"\r\n' + b"Content-Type: application/octet-stream\r\n\r\n" + b"" + data_start + b"\r\nBCDE" + data_end + ) + decoder = MultipartDecoder(boundary) + decoder.receive_data(data) + events = [decoder.next_event()] + # We want to check up to data start event + while not isinstance(events[-1], Data): + events.append(decoder.next_event()) + expected = data_start if data_end == b"" else data_start + b"\r\nBCDE" + assert events == [ + Preamble(data=b""), + File( + name="test", + 
filename="testfile", + headers=Headers( + [ + ( + "Content-Disposition", + 'form-data; name="test"; filename="testfile"', + ), + ("Content-Type", "application/octet-stream"), + ] + ), + ), + Data(data=expected, more_data=True), + ] + + def test_chunked_boundaries() -> None: boundary = b"--boundary" decoder = MultipartDecoder(boundary) @@ -78,3 +131,58 @@ def test_chunked_boundaries() -> None: assert not event.more_data decoder.receive_data(None) assert isinstance(decoder.next_event(), Epilogue) + + +def test_empty_field() -> None: + boundary = b"foo" + decoder = MultipartDecoder(boundary) + data = """ +--foo +Content-Disposition: form-data; name="text" +Content-Type: text/plain; charset="UTF-8" + +Some Text +--foo +Content-Disposition: form-data; name="empty" +Content-Type: text/plain; charset="UTF-8" + +--foo-- + """.replace( + "\n", "\r\n" + ).encode( + "utf-8" + ) + decoder.receive_data(data) + decoder.receive_data(None) + events = [decoder.next_event()] + while not isinstance(events[-1], Epilogue): + events.append(decoder.next_event()) + assert events == [ + Preamble(data=b""), + Field( + name="text", + headers=Headers( + [ + ("Content-Disposition", 'form-data; name="text"'), + ("Content-Type", 'text/plain; charset="UTF-8"'), + ] + ), + ), + Data(data=b"Some Text", more_data=False), + Field( + name="empty", + headers=Headers( + [ + ("Content-Disposition", 'form-data; name="empty"'), + ("Content-Type", 'text/plain; charset="UTF-8"'), + ] + ), + ), + Data(data=b"", more_data=False), + Epilogue(data=b" "), + ] + encoder = MultipartEncoder(boundary) + result = b"" + for event in events: + result += encoder.send_event(event) + assert data == result diff --git a/tests/sansio/test_request.py b/tests/sansio/test_request.py index 310b244..4f4bbd6 100644 --- a/tests/sansio/test_request.py +++ b/tests/sansio/test_request.py @@ -12,6 +12,10 @@ from werkzeug.sansio.request import Request (Headers({"Transfer-Encoding": "chunked", "Content-Length": "6"}), None), 
(Headers({"Transfer-Encoding": "something", "Content-Length": "6"}), 6), (Headers({"Content-Length": "6"}), 6), + (Headers({"Content-Length": "-6"}), 0), + (Headers({"Content-Length": "+123"}), 0), + (Headers({"Content-Length": "1_23"}), 0), + (Headers({"Content-Length": "🯱🯲🯳"}), 0), (Headers(), None), ], ) diff --git a/tests/sansio/test_utils.py b/tests/sansio/test_utils.py index 8c8faa6..04d02e4 100644 --- a/tests/sansio/test_utils.py +++ b/tests/sansio/test_utils.py @@ -1,7 +1,10 @@ +from __future__ import annotations + import typing as t import pytest +from werkzeug.sansio.utils import get_content_length from werkzeug.sansio.utils import get_host @@ -30,3 +33,23 @@ def test_get_host( expected: str, ) -> None: assert get_host(scheme, host_header, server) == expected + + +@pytest.mark.parametrize( + ("http_content_length", "http_transfer_encoding", "expected"), + [ + ("2", None, 2), + (" 2", None, 2), + ("2 ", None, 2), + (None, None, None), + (None, "chunked", None), + ("a", None, 0), + ("-2", None, 0), + ], +) +def test_get_content_length( + http_content_length: str | None, + http_transfer_encoding: str | None, + expected: int | None, +) -> None: + assert get_content_length(http_content_length, http_transfer_encoding) == expected diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 7f63b64..5206aa6 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -63,7 +63,7 @@ class _MutableMultiDictTests: d = create_instance() s = pickle.dumps(d, protocol) ud = pickle.loads(s) - assert type(ud) == type(d) + assert type(ud) == type(d) # noqa: E721 assert ud == d alternative = pickle.dumps(create_instance("werkzeug"), protocol) assert pickle.loads(alternative) == d @@ -731,16 +731,6 @@ class TestHeaders: h[:] = [(k, v) for k, v in h if k.startswith("X-")] assert list(h) == [("X-Foo-Poo", "bleh"), ("X-Forwarded-For", "192.168.0.123")] - def test_bytes_operations(self): - h = self.storage_class() - h.set("X-Foo-Poo", 
"bleh") - h.set("X-Whoops", b"\xff") - h.set(b"X-Bytes", b"something") - - assert h.get("x-foo-poo", as_bytes=True) == b"bleh" - assert h.get("x-whoops", as_bytes=True) == b"\xff" - assert h.get("x-bytes") == "something" - def test_extend(self): h = self.storage_class([("a", "0"), ("b", "1"), ("c", "2")]) h.extend(ds.Headers([("a", "3"), ("a", "4")])) @@ -791,13 +781,6 @@ class TestHeaders: assert key == "Key" assert value == "Value" - def test_to_wsgi_list_bytes(self): - h = self.storage_class() - h.set(b"Key", b"Value") - for key, value in h.to_wsgi_list(): - assert key == "Key" - assert value == "Value" - def test_equality(self): # test equality, given keys are case insensitive h1 = self.storage_class() @@ -853,13 +836,6 @@ class TestEnvironHeaders: assert headers["Foo"] == "\xe2\x9c\x93" assert next(iter(headers)) == ("Foo", "\xe2\x9c\x93") - def test_bytes_operations(self): - foo_val = "\xff" - h = self.storage_class({"HTTP_X_FOO": foo_val}) - - assert h.get("x-foo", as_bytes=True) == b"\xff" - assert h.get("x-foo") == "\xff" - class TestHeaderSet: storage_class = ds.HeaderSet diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index d8fed96..e4ee586 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -96,10 +96,8 @@ def test_method_not_allowed_methods(): def test_unauthorized_www_authenticate(): - basic = WWWAuthenticate() - basic.set_basic("test") - digest = WWWAuthenticate() - digest.set_digest("test", "test") + basic = WWWAuthenticate("basic", {"realm": "test"}) + digest = WWWAuthenticate("digest", {"realm": "test", "nonce": "test"}) exc = exceptions.Unauthorized(www_authenticate=basic) h = Headers(exc.get_headers({})) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index 49010b4..1dcb167 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -69,37 +69,23 @@ class TestFormParser: req.max_form_memory_size = 400 assert req.form["foo"] == "Hello World" + input_stream = 
io.BytesIO(b"foo=123456") + req = Request.from_values( + input_stream=input_stream, + content_type="application/x-www-form-urlencoded", + method="POST", + ) + req.max_content_length = 4 + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) + # content-length was set, so request could exit early without reading anything + assert input_stream.read() == b"foo=123456" + data = ( b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" b"Hello World\r\n" b"--foo\r\nContent-Disposition: form-field; name=bar\r\n\r\n" b"bar=baz\r\n--foo--" ) - req = Request.from_values( - input_stream=io.BytesIO(data), - content_length=len(data), - content_type="multipart/form-data; boundary=foo", - method="POST", - ) - req.max_content_length = 4 - pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - - # when the request entity is too large, the input stream should be - # drained so that firefox (and others) do not report connection reset - # when run through gunicorn - # a sufficiently large stream is necessary for block-based reads - input_stream = io.BytesIO(b"foo=" + b"x" * 128 * 1024) - req = Request.from_values( - input_stream=input_stream, - content_length=len(data), - content_type="multipart/form-data; boundary=foo", - method="POST", - ) - req.max_content_length = 4 - pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) - # ensure that the stream is exhausted - assert input_stream.read() == b"" - req = Request.from_values( input_stream=io.BytesIO(data), content_length=len(data), @@ -127,6 +113,22 @@ class TestFormParser: req.max_form_memory_size = 400 assert req.form["foo"] == "Hello World" + req = Request.from_values( + input_stream=io.BytesIO(data), + content_length=len(data), + content_type="multipart/form-data; boundary=foo", + method="POST", + ) + req.max_form_parts = 1 + pytest.raises(RequestEntityTooLarge, lambda: req.form["foo"]) + + def test_x_www_urlencoded_max_form_parts(self): + r = Request.from_values(method="POST", data={"a": 1, "b": 
2}) + r.max_form_parts = 1 + + assert r.form["a"] == "1" + assert r.form["b"] == "2" + def test_missing_multipart_boundary(self): data = ( b"--foo\r\nContent-Disposition: form-field; name=foo\r\n\r\n" diff --git a/tests/test_http.py b/tests/test_http.py index 3760dc1..bbd51ba 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -1,4 +1,5 @@ import base64 +import urllib.parse from datetime import date from datetime import datetime from datetime import timedelta @@ -9,6 +10,8 @@ import pytest from werkzeug import datastructures from werkzeug import http from werkzeug._internal import _wsgi_encoding_dance +from werkzeug.datastructures import Authorization +from werkzeug.datastructures import WWWAuthenticate from werkzeug.test import create_environ @@ -21,6 +24,10 @@ class TestHTTPUtility: pytest.raises(ValueError, a.index, "de") assert a.to_header() == "en-us,ru;q=0.5" + def test_accept_parameter_with_space(self): + a = http.parse_accept_header('application/x-special; z="a b";q=0.5') + assert a['application/x-special; z="a b"'] == 0.5 + def test_mime_accept(self): a = http.parse_accept_header( "text/xml,application/xml," @@ -88,9 +95,17 @@ class TestHTTPUtility: hs.add("Foo") assert hs.to_header() == 'foo, Bar, "Blah baz", Hehe' - def test_list_header(self): - hl = http.parse_list_header("foo baz, blah") - assert hl == ["foo baz", "blah"] + @pytest.mark.parametrize( + ("value", "expect"), + [ + ("a b", ["a b"]), + ("a b, c", ["a b", "c"]), + ('a b, "c, d"', ["a b", "c, d"]), + ('"a\\"b", c', ['a"b', "c"]), + ], + ) + def test_list_header(self, value, expect): + assert http.parse_list_header(value) == expect def test_dict_header(self): d = http.parse_dict_header('foo="bar baz", blah=42') @@ -133,33 +148,30 @@ class TestHTTPUtility: assert csp.img_src is None def test_authorization_header(self): - a = http.parse_authorization_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + a = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") assert a.type == "basic" 
assert a.username == "Aladdin" assert a.password == "open sesame" - a = http.parse_authorization_header( - "Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==" - ) + a = Authorization.from_header("Basic 0YDRg9GB0YHQutC40IE60JHRg9C60LLRiw==") assert a.type == "basic" assert a.username == "русскиЁ" assert a.password == "Буквы" - a = http.parse_authorization_header("Basic 5pmu6YCa6K+dOuS4reaWhw==") + a = Authorization.from_header("Basic 5pmu6YCa6K+dOuS4reaWhw==") assert a.type == "basic" assert a.username == "普通话" assert a.password == "中文" - a = http.parse_authorization_header( - '''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - qop=auth, - nc=00000001, - cnonce="0a4f113b", - response="6629fae49393a05397450978507c4ef1", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + a = Authorization.from_header( + 'Digest username="Mufasa",' + ' realm="testrealm@host.invalid",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' uri="/dir/index.html",' + " qop=auth, nc=00000001," + ' cnonce="0a4f113b",' + ' response="6629fae49393a05397450978507c4ef1",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert a.type == "digest" assert a.username == "Mufasa" @@ -172,13 +184,13 @@ class TestHTTPUtility: assert a.response == "6629fae49393a05397450978507c4ef1" assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - a = http.parse_authorization_header( - '''Digest username="Mufasa", - realm="testrealm@host.invalid", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - uri="/dir/index.html", - response="e257afa1414a3340d93d30955171dd0e", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + a = Authorization.from_header( + 'Digest username="Mufasa",' + ' realm="testrealm@host.invalid",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' uri="/dir/index.html",' + ' response="e257afa1414a3340d93d30955171dd0e",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert a.type == "digest" assert a.username == "Mufasa" @@ 
-188,41 +200,87 @@ class TestHTTPUtility: assert a.response == "e257afa1414a3340d93d30955171dd0e" assert a.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - assert http.parse_authorization_header("") is None - assert http.parse_authorization_header(None) is None - assert http.parse_authorization_header("foo") is None + assert Authorization.from_header("") is None + assert Authorization.from_header(None) is None + assert Authorization.from_header("foo").type == "foo" + + def test_authorization_token_padding(self): + # padded with = + token = base64.b64encode(b"This has base64 padding").decode() + a = Authorization.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + # padded with == + token = base64.b64encode(b"This has base64 padding..").decode() + a = Authorization.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + def test_authorization_basic_incorrect_padding(self): + assert Authorization.from_header("Basic foo") is None def test_bad_authorization_header_encoding(self): """If the base64 encoded bytes can't be decoded as UTF-8""" content = base64.b64encode(b"\xffser:pass").decode() - assert http.parse_authorization_header(f"Basic {content}") is None + assert Authorization.from_header(f"Basic {content}") is None + + def test_authorization_eq(self): + basic1 = Authorization.from_header("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==") + basic2 = Authorization( + "basic", {"username": "Aladdin", "password": "open sesame"} + ) + assert basic1 == basic2 + bearer1 = Authorization.from_header("Bearer abc") + bearer2 = Authorization("bearer", token="abc") + assert bearer1 == bearer2 + assert basic1 != bearer1 + assert basic1 != object() def test_www_authenticate_header(self): - wa = http.parse_www_authenticate_header('Basic realm="WallyWorld"') + wa = WWWAuthenticate.from_header('Basic realm="WallyWorld"') assert wa.type == "basic" assert wa.realm == "WallyWorld" wa.realm = "Foo Bar" assert wa.to_header() == 'Basic 
realm="Foo Bar"' - wa = http.parse_www_authenticate_header( - '''Digest - realm="testrealm@host.com", - qop="auth,auth-int", - nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093", - opaque="5ccc069c403ebaf9f0171e9517f40e41"''' + wa = WWWAuthenticate.from_header( + 'Digest realm="testrealm@host.com",' + ' qop="auth,auth-int",' + ' nonce="dcd98b7102dd2f0e8b11d0f600bfb0c093",' + ' opaque="5ccc069c403ebaf9f0171e9517f40e41"' ) assert wa.type == "digest" assert wa.realm == "testrealm@host.com" - assert "auth" in wa.qop - assert "auth-int" in wa.qop + assert wa.parameters["qop"] == "auth,auth-int" assert wa.nonce == "dcd98b7102dd2f0e8b11d0f600bfb0c093" assert wa.opaque == "5ccc069c403ebaf9f0171e9517f40e41" - wa = http.parse_www_authenticate_header("broken") - assert wa.type == "broken" + assert WWWAuthenticate.from_header("broken").type == "broken" + assert WWWAuthenticate.from_header("") is None - assert not http.parse_www_authenticate_header("").type - assert not http.parse_www_authenticate_header("") + def test_www_authenticate_token_padding(self): + # padded with = + token = base64.b64encode(b"This has base64 padding").decode() + a = WWWAuthenticate.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + # padded with == + token = base64.b64encode(b"This has base64 padding..").decode() + a = WWWAuthenticate.from_header(f"Token {token}") + assert a.type == "token" + assert a.token == token + + def test_www_authenticate_eq(self): + basic1 = WWWAuthenticate.from_header("Basic realm=abc") + basic2 = WWWAuthenticate("basic", {"realm": "abc"}) + assert basic1 == basic2 + token1 = WWWAuthenticate.from_header("Token abc") + token2 = WWWAuthenticate("token", token="abc") + assert token1 == token2 + assert basic1 != token1 + assert basic1 != object() def test_etags(self): assert http.quote_etag("foo") == '"foo"' @@ -274,68 +332,63 @@ class TestHTTPUtility: http.remove_hop_by_hop_headers(headers2) assert headers2 == datastructures.Headers([("Foo", "bar")]) 
- def test_parse_options_header(self): - assert http.parse_options_header(None) == ("", {}) - assert http.parse_options_header("") == ("", {}) - assert http.parse_options_header(r'something; foo="other\"thing"') == ( - "something", - {"foo": 'other"thing'}, - ) - assert http.parse_options_header(r'something; foo="other\"thing"; meh=42') == ( - "something", - {"foo": 'other"thing', "meh": "42"}, - ) - assert http.parse_options_header( - r'something; foo="other\"thing"; meh=42; bleh' - ) == ("something", {"foo": 'other"thing', "meh": "42", "bleh": None}) - assert http.parse_options_header( - 'something; foo="other;thing"; meh=42; bleh' - ) == ("something", {"foo": "other;thing", "meh": "42", "bleh": None}) - assert http.parse_options_header('something; foo="otherthing"; meh=; bleh') == ( - "something", - {"foo": "otherthing", "meh": None, "bleh": None}, - ) - # Issue #404 - assert http.parse_options_header( - 'multipart/form-data; name="foo bar"; filename="bar foo"' - ) == ("multipart/form-data", {"name": "foo bar", "filename": "bar foo"}) - # Examples from RFC - assert http.parse_options_header("audio/*; q=0.2, audio/basic") == ( - "audio/*", - {"q": "0.2"}, - ) + @pytest.mark.parametrize( + ("value", "expect"), + [ + (None, ""), + ("", ""), + (";a=b", ""), + ("v", "v"), + ("v;", "v"), + ], + ) + def test_parse_options_header_empty(self, value, expect): + assert http.parse_options_header(value) == (expect, {}) - assert http.parse_options_header( - "text/plain; q=0.5, text/html\n text/x-dvi; q=0.8, text/x-c" - ) == ("text/plain", {"q": "0.5"}) - # Issue #932 - assert http.parse_options_header( - "form-data; name=\"a_file\"; filename*=UTF-8''" - '"%c2%a3%20and%20%e2%82%ac%20rates"' - ) == ("form-data", {"name": "a_file", "filename": "\xa3 and \u20ac rates"}) - assert http.parse_options_header( - "form-data; name*=UTF-8''\"%C5%AAn%C4%ADc%C5%8Dde%CC%BD\"; " - 'filename="some_file.txt"' - ) == ( - "form-data", - {"name": "\u016an\u012dc\u014dde\u033d", "filename": 
"some_file.txt"}, - ) - - def test_parse_options_header_value_with_quotes(self): - assert http.parse_options_header( - 'form-data; name="file"; filename="t\'es\'t.txt"' - ) == ("form-data", {"name": "file", "filename": "t'es't.txt"}) - assert http.parse_options_header( - "form-data; name=\"file\"; filename*=UTF-8''\"'🐍'.txt\"" - ) == ("form-data", {"name": "file", "filename": "'🐍'.txt"}) + @pytest.mark.parametrize( + ("value", "expect"), + [ + ("v;a=b;c=d;", {"a": "b", "c": "d"}), + ("v; ; a=b ; ", {"a": "b"}), + ("v;a", {}), + ("v;a=", {}), + ("v;=b", {}), + ('v;a="b"', {"a": "b"}), + ("v;a=µ", {}), + ('v;a="\';\'";b="µ";', {"a": "';'", "b": "µ"}), + ('v;a="b c"', {"a": "b c"}), + # HTTP headers use \" for internal " + ('v;a="b\\"c";d=e', {"a": 'b"c', "d": "e"}), + # HTTP headers use \\ for internal \ + ('v;a="c:\\\\"', {"a": "c:\\"}), + # Invalid trailing slash in quoted part is left as-is. + ('v;a="c:\\"', {"a": "c:\\"}), + ('v;a="b\\\\\\"c"', {"a": 'b\\"c'}), + # multipart form data uses %22 for internal " + ('v;a="b%22c"', {"a": 'b"c'}), + ("v;a*=b", {"a": "b"}), + ("v;a*=ASCII'en'b", {"a": "b"}), + ("v;a*=US-ASCII''%62", {"a": "b"}), + ("v;a*=UTF-8''%C2%B5", {"a": "µ"}), + ("v;a*=US-ASCII''%C2%B5", {"a": "��"}), + ("v;a*=BAD''%62", {"a": "%62"}), + ("v;a*=UTF-8'''%F0%9F%90%8D'.txt", {"a": "'🐍'.txt"}), + ('v;a="🐍.txt"', {"a": "🐍.txt"}), + ("v;a*0=b;a*1=c;d=e", {"a": "bc", "d": "e"}), + ("v;a*0*=b", {"a": "b"}), + ("v;a*0*=UTF-8''b;a*1=c;a*2*=%C2%B5", {"a": "bcµ"}), + ], + ) + def test_parse_options_header(self, value, expect) -> None: + assert http.parse_options_header(value) == ("v", expect) def test_parse_options_header_broken_values(self): # Issue #995 assert http.parse_options_header(" ") == ("", {}) - assert http.parse_options_header(" , ") == ("", {}) + assert http.parse_options_header(" , ") == (",", {}) assert http.parse_options_header(" ; ") == ("", {}) - assert http.parse_options_header(" ,; ") == ("", {}) - assert http.parse_options_header(" , a ") 
== ("", {}) + assert http.parse_options_header(" ,; ") == (",", {}) + assert http.parse_options_header(" , a ") == (", a", {}) assert http.parse_options_header(" ; a ") == ("", {}) def test_parse_options_header_case_insensitive(self): @@ -344,16 +397,12 @@ class TestHTTPUtility: def test_dump_options_header(self): assert http.dump_options_header("foo", {"bar": 42}) == "foo; bar=42" - assert http.dump_options_header("foo", {"bar": 42, "fizz": None}) in ( - "foo; bar=42; fizz", - "foo; fizz; bar=42", - ) + assert "fizz" not in http.dump_options_header("foo", {"bar": 42, "fizz": None}) def test_dump_header(self): assert http.dump_header([1, 2, 3]) == "1, 2, 3" - assert http.dump_header([1, 2, 3], allow_token=False) == '"1", "2", "3"' - assert http.dump_header({"foo": "bar"}, allow_token=False) == 'foo="bar"' assert http.dump_header({"foo": "bar"}) == "foo=bar" + assert http.dump_header({"foo*": "UTF-8''bar"}) == "foo*=UTF-8''bar" def test_is_resource_modified(self): env = create_environ() @@ -411,7 +460,8 @@ class TestHTTPUtility: def test_parse_cookie(self): cookies = http.parse_cookie( "dismiss-top=6; CP=null*; PHPSESSID=0a539d42abc001cdc762809248d4beed;" - 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d' + 'a=42; b="\\";"; ; fo234{=bar;blub=Blah; "__Secure-c"=d;' + "==__Host-eq=bad;__Host-eq=good;" ) assert cookies.to_dict() == { "CP": "null*", @@ -422,6 +472,7 @@ class TestHTTPUtility: "fo234{": "bar", "blub": "Blah", '"__Secure-c"': "d", + "__Host-eq": "good", } def test_dump_cookie(self): @@ -435,7 +486,7 @@ class TestHTTPUtility: 'foo="bar baz blub"', } assert http.dump_cookie("key", "xxx/") == "key=xxx/; Path=/" - assert http.dump_cookie("key", "xxx=") == "key=xxx=; Path=/" + assert http.dump_cookie("key", "xxx=", path=None) == "key=xxx=" def test_bad_cookies(self): cookies = http.parse_cookie( @@ -458,9 +509,9 @@ class TestHTTPUtility: def test_cookie_quoting(self): val = http.dump_cookie("foo", "?foo") - assert val == 'foo="?foo"; Path=/' - assert 
http.parse_cookie(val).to_dict() == {"foo": "?foo", "Path": "/"} - assert http.parse_cookie(r'foo="foo\054bar"').to_dict(), {"foo": "foo,bar"} + assert val == "foo=?foo; Path=/" + assert http.parse_cookie(val)["foo"] == "?foo" + assert http.parse_cookie(r'foo="foo\054bar"')["foo"] == "foo,bar" def test_parse_set_cookie_directive(self): val = 'foo="?foo"; version="0.1";' @@ -482,7 +533,7 @@ class TestHTTPUtility: def test_cookie_unicode_keys(self): # Yes, this is technically against the spec but happens val = http.dump_cookie("fö", "fö") - assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/', "utf-8") + assert val == _wsgi_encoding_dance('fö="f\\303\\266"; Path=/') cookies = http.parse_cookie(val) assert cookies["fö"] == "fö" @@ -495,38 +546,30 @@ class TestHTTPUtility: val = http.dump_cookie("foo", "bar", domain="\N{SNOWMAN}.com") assert val == "foo=bar; Domain=xn--n3h.com; Path=/" - val = http.dump_cookie("foo", "bar", domain=".\N{SNOWMAN}.com") - assert val == "foo=bar; Domain=.xn--n3h.com; Path=/" + val = http.dump_cookie("foo", "bar", domain="foo.com") + assert val == "foo=bar; Domain=foo.com; Path=/" - val = http.dump_cookie("foo", "bar", domain=".foo.com") - assert val == "foo=bar; Domain=.foo.com; Path=/" - - def test_cookie_maxsize(self, recwarn): + def test_cookie_maxsize(self): val = http.dump_cookie("foo", "bar" * 1360 + "b") - assert len(recwarn) == 0 assert len(val) == 4093 - http.dump_cookie("foo", "bar" * 1360 + "ba") - assert len(recwarn) == 1 - w = recwarn.pop() - assert "cookie is too large" in str(w.message) + with pytest.warns(UserWarning, match="cookie is too large"): + http.dump_cookie("foo", "bar" * 1360 + "ba") - http.dump_cookie("foo", b"w" * 502, max_size=512) - assert len(recwarn) == 1 - w = recwarn.pop() - assert "the limit is 512 bytes" in str(w.message) + with pytest.warns(UserWarning, match="the limit is 512 bytes"): + http.dump_cookie("foo", "w" * 501, max_size=512) @pytest.mark.parametrize( ("samesite", "expected"), ( - 
("strict", "foo=bar; Path=/; SameSite=Strict"), - ("lax", "foo=bar; Path=/; SameSite=Lax"), - ("none", "foo=bar; Path=/; SameSite=None"), - (None, "foo=bar; Path=/"), + ("strict", "foo=bar; SameSite=Strict"), + ("lax", "foo=bar; SameSite=Lax"), + ("none", "foo=bar; SameSite=None"), + (None, "foo=bar"), ), ) def test_cookie_samesite_attribute(self, samesite, expected): - value = http.dump_cookie("foo", "bar", samesite=samesite) + value = http.dump_cookie("foo", "bar", samesite=samesite, path=None) assert value == expected def test_cookie_samesite_invalid(self): @@ -619,6 +662,9 @@ class TestRange: rv = http.parse_content_range_header("bytes 0-98/*asdfsa") assert rv is None + rv = http.parse_content_range_header("bytes */-1") + assert rv is None + rv = http.parse_content_range_header("bytes 0-99/100") assert rv.to_header() == "bytes 0-99/100" rv.start = None @@ -656,7 +702,7 @@ class TestRegression: ], ) def test_authorization_to_header(value: str) -> None: - parsed = http.parse_authorization_header(value) + parsed = Authorization.from_header(value) assert parsed is not None assert parsed.to_header() == value @@ -715,3 +761,32 @@ def test_parse_date(value, expect): ) def test_http_date(value, expect): assert http.http_date(value) == expect + + +@pytest.mark.parametrize("value", [".5", "+0.5", "0.5_1", "🯰.🯵"]) +def test_accept_invalid_float(value): + quoted = urllib.parse.quote(value) + + if quoted == value: + q = f"q={value}" + else: + q = f"q*=UTF-8''{value}" + + a = http.parse_accept_header(f"en,jp;{q}") + assert list(a.values()) == ["en"] + + +def test_accept_valid_int_one_zero(): + assert http.parse_accept_header("en;q=1") == http.parse_accept_header("en;q=1.0") + assert http.parse_accept_header("en;q=0") == http.parse_accept_header("en;q=0.0") + assert http.parse_accept_header("en;q=5") == http.parse_accept_header("en;q=5.0") + + +@pytest.mark.parametrize("value", ["🯱🯲🯳", "+1-", "1-1_23"]) +def test_range_invalid_int(value): + assert 
http.parse_range_header(value) is None + + +@pytest.mark.parametrize("value", ["*/🯱🯲🯳", "1-+2/3", "1_23-125/*"]) +def test_content_range_invalid_int(value): + assert http.parse_content_range_header(f"bytes {value}") is None diff --git a/tests/test_internal.py b/tests/test_internal.py index 6e673fd..edae35b 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -1,21 +1,8 @@ -from warnings import filterwarnings -from warnings import resetwarnings - -import pytest - -from werkzeug import _internal as internal from werkzeug.test import create_environ from werkzeug.wrappers import Request from werkzeug.wrappers import Response -def test_easteregg(): - req = Request.from_values("/?macgybarchakku") - resp = Response.force_type(internal._easteregg(None), req) - assert b"About Werkzeug" in resp.get_data() - assert b"the Swiss Army knife of Python web development" in resp.get_data() - - def test_wrapper_internals(): req = Request.from_values(data={"foo": "bar"}, method="POST") req._load_form_data() @@ -34,23 +21,10 @@ def test_wrapper_internals(): resp.response = iter(["Test"]) assert repr(resp) == "" - # string data does not set content length response = Response(["Hällo Wörld"]) headers = response.get_wsgi_headers(create_environ()) - assert "Content-Length" not in headers + assert "Content-Length" in headers response = Response(["Hällo Wörld".encode()]) headers = response.get_wsgi_headers(create_environ()) assert "Content-Length" in headers - - # check for internal warnings - filterwarnings("error", category=Warning) - response = Response() - environ = create_environ() - response.response = "What the...?" 
- pytest.raises(Warning, lambda: list(response.iter_encoded())) - pytest.raises(Warning, lambda: list(response.get_app_iter(environ))) - response.direct_passthrough = True - pytest.raises(Warning, lambda: list(response.iter_encoded())) - pytest.raises(Warning, lambda: list(response.get_app_iter(environ))) - resetwarnings() diff --git a/tests/test_routing.py b/tests/test_routing.py index 15d25a7..65d2a5f 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -163,6 +163,7 @@ def test_strict_slashes_redirect(): r.Rule("/bar/", endpoint="get", methods=["GET"]), r.Rule("/bar", endpoint="post", methods=["POST"]), r.Rule("/foo/", endpoint="foo", methods=["POST"]), + r.Rule("//", endpoint="path", methods=["GET"]), ] ) adapter = map.bind("example.org", "/") @@ -170,6 +171,7 @@ def test_strict_slashes_redirect(): # Check if the actual routes works assert adapter.match("/bar/", method="GET") == ("get", {}) assert adapter.match("/bar", method="POST") == ("post", {}) + assert adapter.match("/abc/", method="GET") == ("path", {"var": "abc"}) # Check if exceptions are correct pytest.raises(r.RequestRedirect, adapter.match, "/bar", method="GET") @@ -177,6 +179,9 @@ def test_strict_slashes_redirect(): with pytest.raises(r.RequestRedirect) as error_info: adapter.match("/foo", method="POST") assert error_info.value.code == 308 + with pytest.raises(r.RequestRedirect) as error_info: + adapter.match("/abc", method="GET") + assert error_info.value.new_url == "http://example.org/abc/" # Check differently defined order map = r.Map( @@ -581,7 +586,8 @@ def test_server_name_interpolation(): with pytest.warns(UserWarning): adapter = map.bind_to_environ(env, server_name="foo") - assert adapter.subdomain == "" + + assert adapter.subdomain == "" def test_rule_emptying(): @@ -742,7 +748,7 @@ def test_uuid_converter(): m = r.Map([r.Rule("/a/", endpoint="a")]) a = m.bind("example.org", "/") route, kwargs = a.match("/a/a8098c1a-f86e-11da-bd1a-00112444be1e") - assert type(kwargs["a_uuid"]) 
== uuid.UUID + assert type(kwargs["a_uuid"]) == uuid.UUID # noqa: E721 def test_converter_with_tuples(): @@ -773,6 +779,35 @@ def test_converter_with_tuples(): assert kwargs["foo"] == ("qwert", "yuiop") +def test_nested_regex_groups(): + """ + Regression test for https://github.com/pallets/werkzeug/issues/2590 + """ + + class RegexConverter(r.BaseConverter): + def __init__(self, url_map, *items): + super().__init__(url_map) + self.part_isolating = False + self.regex = items[0] + + # This is a regex pattern with nested groups + DATE_PATTERN = r"((\d{8}T\d{6}([.,]\d{1,3})?)|(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}([.,]\d{1,3})?))Z" # noqa: B950 + + map = r.Map( + [ + r.Rule( + f"///", + endpoint="handler", + ) + ], + converters={"regex": RegexConverter}, + ) + a = map.bind("example.org", "/") + route, kwargs = a.match("/2023-02-16T23:36:36.266Z/2023-02-16T23:46:36.266Z/") + assert kwargs["start"] == "2023-02-16T23:36:36.266Z" + assert kwargs["end"] == "2023-02-16T23:46:36.266Z" + + def test_anyconverter(): m = r.Map( [ @@ -800,6 +835,20 @@ def test_any_converter_build_validates_value() -> None: assert str(exc.value) == "'invalid' is not one of 'patient', 'provider'" +def test_part_isolating_default() -> None: + class TwoConverter(r.BaseConverter): + regex = r"\w+/\w+" + + def to_python(self, value: str) -> t.Any: + return value.split("/") + + m = r.Map( + [r.Rule("//", endpoint="two")], converters={"two": TwoConverter} + ) + a = m.bind("localhost") + assert a.match("/a/b/") == ("two", {"values": ["a", "b"]}) + + @pytest.mark.parametrize( ("endpoint", "value", "expect"), [ @@ -874,6 +923,7 @@ def test_build_values_multidict(endpoint, value, expect): ([1, 2], "?v=1&v=2"), ([1, None, 2], "?v=1&v=2"), ([1, "", 2], "?v=1&v=&v=2"), + ("1+2", "?v=1%2B2"), ], ) def test_build_append_unknown_dict(value, expect): @@ -910,8 +960,7 @@ def test_build_drop_none(): adapter = map.bind("", "/") params = {"flub": None, "flop": None} with pytest.raises(r.BuildError): - x = 
adapter.build("endp", params) - assert not x + adapter.build("endp", params) params = {"flub": "x", "flop": None} url = adapter.build("endp", params) assert "flop" not in url @@ -992,7 +1041,16 @@ def test_external_building_with_port_bind_to_environ_wrong_servername(): with pytest.warns(UserWarning): adapter = map.bind_to_environ(environ, server_name="example.org") - assert adapter.subdomain == "" + + assert adapter.subdomain == "" + + +def test_bind_long_idna_name_with_port(): + map = r.Map([r.Rule("/", endpoint="index")]) + adapter = map.bind("🐍" + "a" * 52 + ":8443") + name, _, port = adapter.server_name.partition(":") + assert len(name) == 63 + assert port == "8443" def test_converter_parser(): @@ -1071,18 +1129,6 @@ def test_double_defaults(prefix): assert a.build("x", {"bar": True}) == f"{prefix}/bar/" -def test_building_bytes(): - m = r.Map( - [ - r.Rule("/", endpoint="a"), - r.Rule("/", defaults={"b": b"\x01\x02\x03"}, endpoint="b"), - ] - ) - a = m.bind("example.org", "/") - assert a.build("a", {"a": b"\x01\x02\x03"}) == "/%01%02%03" - assert a.build("b") == "/%01%02%03" - - def test_host_matching(): m = r.Map( [ @@ -1434,6 +1480,9 @@ def test_strict_slashes_false(): [ r.Rule("/path1", endpoint="leaf_path", strict_slashes=False), r.Rule("/path2/", endpoint="branch_path", strict_slashes=False), + r.Rule( + "/", endpoint="leaf_path_converter", strict_slashes=False + ), ], ) @@ -1443,12 +1492,19 @@ def test_strict_slashes_false(): assert adapter.match("/path1/", method="GET") == ("leaf_path", {}) assert adapter.match("/path2", method="GET") == ("branch_path", {}) assert adapter.match("/path2/", method="GET") == ("branch_path", {}) + assert adapter.match("/any", method="GET") == ( + "leaf_path_converter", + {"path": "any"}, + ) + assert adapter.match("/any/", method="GET") == ( + "leaf_path_converter", + {"path": "any/"}, + ) def test_invalid_rule(): with pytest.raises(ValueError): - map_ = r.Map([r.Rule("/", endpoint="test")]) - map_.bind("localhost") + 
r.Map([r.Rule("/", endpoint="test")]) def test_multiple_converters_per_part(): diff --git a/tests/test_security.py b/tests/test_security.py index 3e797fc..6fad089 100644 --- a/tests/test_security.py +++ b/tests/test_security.py @@ -1,5 +1,6 @@ import os import posixpath +import sys import pytest @@ -8,25 +9,42 @@ from werkzeug.security import generate_password_hash from werkzeug.security import safe_join -def test_password_hashing(): - hash0 = generate_password_hash("default") - assert check_password_hash(hash0, "default") - assert hash0.startswith("pbkdf2:sha256:260000$") +def test_default_password_method(): + value = generate_password_hash("secret") + assert value.startswith("scrypt:") - hash1 = generate_password_hash("default", "sha1") - hash2 = generate_password_hash("default", method="sha1") + +@pytest.mark.xfail( + sys.implementation.name == "pypy", reason="scrypt unavailable on pypy" +) +def test_scrypt(): + value = generate_password_hash("secret", method="scrypt") + assert check_password_hash(value, "secret") + assert value.startswith("scrypt:32768:8:1$") + + +def test_pbkdf2(): + value = generate_password_hash("secret", method="pbkdf2") + assert check_password_hash(value, "secret") + assert value.startswith("pbkdf2:sha256:600000$") + + +def test_salted_hashes(): + hash1 = generate_password_hash("secret") + hash2 = generate_password_hash("secret") assert hash1 != hash2 - assert check_password_hash(hash1, "default") - assert check_password_hash(hash2, "default") - assert hash1.startswith("sha1$") - assert hash2.startswith("sha1$") + assert check_password_hash(hash1, "secret") + assert check_password_hash(hash2, "secret") + +def test_require_salt(): with pytest.raises(ValueError): - generate_password_hash("default", "sha1", salt_length=0) + generate_password_hash("secret", salt_length=0) - fakehash = generate_password_hash("default", method="plain") - assert fakehash == "plain$$default" - assert check_password_hash(fakehash, "default") + +def 
test_invalid_method(): + with pytest.raises(ValueError, match="Invalid hash method"): + generate_password_hash("secret", "sha256") def test_safe_join(): diff --git a/tests/test_send_file.py b/tests/test_send_file.py index fc4299a..4aa69f2 100644 --- a/tests/test_send_file.py +++ b/tests/test_send_file.py @@ -107,6 +107,9 @@ def test_object_attachment_requires_name(): ("Vögel.txt", "Vogel.txt", "V%C3%B6gel.txt"), # ":/" are not safe in filename* value ("те:/ст", '":/"', "%D1%82%D0%B5%3A%2F%D1%81%D1%82"), + # general test of extended parameter (non-quoted) + ("(тест.txt", '"(.txt"', "%28%D1%82%D0%B5%D1%81%D1%82.txt"), + ("(test.txt", '"(test.txt"', None), ), ) def test_non_ascii_name(name, ascii, utf8): diff --git a/tests/test_serving.py b/tests/test_serving.py index ecdb15a..4abc755 100644 --- a/tests/test_serving.py +++ b/tests/test_serving.py @@ -7,13 +7,18 @@ import ssl import sys from io import BytesIO from pathlib import Path +from unittest.mock import patch import pytest +from watchdog.events import EVENT_TYPE_MODIFIED +from watchdog.events import EVENT_TYPE_OPENED +from watchdog.events import FileModifiedEvent from werkzeug import run_simple from werkzeug._reloader import _find_stat_paths from werkzeug._reloader import _find_watchdog_paths from werkzeug._reloader import _get_args_for_reloading +from werkzeug._reloader import WatchdogReloaderLoop from werkzeug.datastructures import FileStorage from werkzeug.serving import make_ssl_devcert from werkzeug.test import stream_encode_multipart @@ -115,6 +120,23 @@ def test_reloader_sys_path(tmp_path, dev_server, reloader_type): assert client.request().status == 200 +@patch.object(WatchdogReloaderLoop, "trigger_reload") +def test_watchdog_reloader_ignores_opened(mock_trigger_reload): + reloader = WatchdogReloaderLoop() + modified_event = FileModifiedEvent("") + modified_event.event_type = EVENT_TYPE_MODIFIED + reloader.event_handler.on_any_event(modified_event) + mock_trigger_reload.assert_called_once() + + 
reloader.trigger_reload.reset_mock() + + opened_event = FileModifiedEvent("") + opened_event.event_type = EVENT_TYPE_OPENED + reloader.event_handler.on_any_event(opened_event) + reloader.trigger_reload.assert_not_called() + + +@pytest.mark.skipif(sys.version_info >= (3, 10), reason="not needed on >= 3.10") def test_windows_get_args_for_reloading(monkeypatch, tmp_path): argv = [str(tmp_path / "test.exe"), "run"] monkeypatch.setattr("sys.executable", str(tmp_path / "python.exe")) @@ -125,6 +147,20 @@ def test_windows_get_args_for_reloading(monkeypatch, tmp_path): assert rv == argv +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@pytest.mark.parametrize("find", [_find_stat_paths, _find_watchdog_paths]) +def test_exclude_patterns(find): + # Select a path to exclude from the unfiltered list, assert that it is present and + # then gets excluded. + paths = find(set(), set()) + path_to_exclude = next(iter(paths)) + assert any(p.startswith(path_to_exclude) for p in paths) + + # Those paths should be excluded due to the pattern. + paths = find(set(), {f"{path_to_exclude}*"}) + assert not any(p.startswith(path_to_exclude) for p in paths) + + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_wrong_protocol(standard_app): @@ -244,6 +280,7 @@ def test_multiline_header_folding(standard_app): @pytest.mark.parametrize("endpoint", ["", "crash"]) +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_streaming_close_response(dev_server, endpoint): """When using HTTP/1.0, chunked encoding is not supported. 
Fall @@ -255,6 +292,7 @@ def test_streaming_close_response(dev_server, endpoint): assert r.data == "".join(str(x) + "\n" for x in range(5)).encode() +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") @pytest.mark.dev_server def test_streaming_chunked_response(dev_server): """When using HTTP/1.1, use Transfer-Encoding: chunked for streamed diff --git a/tests/test_test.py b/tests/test_test.py index 02d637e..c7f21fa 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -10,7 +10,6 @@ from werkzeug.datastructures import FileStorage from werkzeug.datastructures import Headers from werkzeug.datastructures import MultiDict from werkzeug.formparser import parse_form_data -from werkzeug.http import parse_authorization_header from werkzeug.test import Client from werkzeug.test import ClientRedirectError from werkzeug.test import create_environ @@ -74,7 +73,7 @@ def multi_value_post_app(environ, start_response): def test_cookie_forging(): c = Client(cookie_app) - c.set_cookie("localhost", "foo", "bar") + c.set_cookie("foo", "bar") response = c.open() assert response.text == "foo=bar" @@ -88,7 +87,7 @@ def test_set_cookie_app(): def test_cookiejar_stores_cookie(): c = Client(cookie_app) c.open() - assert "test" in c.cookie_jar._cookies["localhost.local"]["/"] + assert c.get_cookie("test") is not None def test_no_initial_cookie(): @@ -118,6 +117,25 @@ def test_cookie_for_different_path(): assert response.text == "test=test" +def test_cookie_default_path() -> None: + """When no path is set for a cookie, the default uses everything up to but not + including the first slash. 
+ """ + + @Request.application + def app(request: Request) -> Response: + r = Response() + r.set_cookie("k", "v", path=None) + return r + + c = Client(app) + c.get("/nested/leaf") + assert c.get_cookie("k") is None + assert c.get_cookie("k", path="/nested") is not None + c.get("/nested/dir/") + assert c.get_cookie("k", path="/nested/dir") is not None + + def test_environ_builder_basics(): b = EnvironBuilder() assert b.content_type is None @@ -284,9 +302,8 @@ def test_environ_builder_content_type(): def test_basic_auth(): builder = EnvironBuilder(auth=("username", "password")) request = builder.get_request() - auth = parse_authorization_header(request.headers["Authorization"]) - assert auth.username == "username" - assert auth.password == "password" + assert request.authorization.username == "username" + assert request.authorization.password == "password" def test_auth_object(): @@ -340,6 +357,23 @@ def test_environ_builder_unicode_file_mix(): files["f"].close() +def test_environ_builder_empty_file(): + f = FileStorage(BytesIO(rb""), "empty.txt") + d = MultiDict(dict(f=f, s="")) + stream, length, boundary = stream_encode_multipart(d) + _, form, files = parse_form_data( + { + "wsgi.input": stream, + "CONTENT_LENGTH": str(length), + "CONTENT_TYPE": f'multipart/form-data; boundary="{boundary}"', + } + ) + assert form["s"] == "" + assert files["f"].read() == rb"" + stream.close() + files["f"].close() + + def test_create_environ(): env = create_environ("/foo?bar=baz", "http://example.org/") expected = { @@ -392,7 +426,7 @@ def test_file_closing(): class SpecialInput: def read(self, size): - return "" + return b"" def close(self): closed.append(self) @@ -752,8 +786,8 @@ def test_multiple_cookies(): @Request.application def test_app(request): response = Response(repr(sorted(request.cookies.items()))) - response.set_cookie("test1", b"foo") - response.set_cookie("test2", b"bar") + response.set_cookie("test1", "foo") + response.set_cookie("test2", "bar") return response 
client = Client(test_app) diff --git a/tests/test_urls.py b/tests/test_urls.py index a409709..fdaa913 100644 --- a/tests/test_urls.py +++ b/tests/test_urls.py @@ -1,240 +1,26 @@ -import io - import pytest from werkzeug import urls -from werkzeug.datastructures import OrderedMultiDict - - -def test_parsing(): - url = urls.url_parse("http://anon:hunter2@[2001:db8:0:1]:80/a/b/c") - assert url.netloc == "anon:hunter2@[2001:db8:0:1]:80" - assert url.username == "anon" - assert url.password == "hunter2" - assert url.port == 80 - assert url.ascii_host == "2001:db8:0:1" - - assert url.get_file_location() == (None, None) # no file scheme - - -@pytest.mark.parametrize("implicit_format", (True, False)) -@pytest.mark.parametrize("localhost", ("127.0.0.1", "::1", "localhost")) -def test_fileurl_parsing_windows(implicit_format, localhost, monkeypatch): - if implicit_format: - pathformat = None - monkeypatch.setattr("os.name", "nt") - else: - pathformat = "windows" - monkeypatch.delattr("os.name") # just to make sure it won't get used - - url = urls.url_parse("file:///C:/Documents and Settings/Foobar/stuff.txt") - assert url.netloc == "" - assert url.scheme == "file" - assert url.get_file_location(pathformat) == ( - None, - r"C:\Documents and Settings\Foobar\stuff.txt", - ) - - url = urls.url_parse("file://///server.tld/file.txt") - assert url.get_file_location(pathformat) == ("server.tld", r"file.txt") - - url = urls.url_parse("file://///server.tld") - assert url.get_file_location(pathformat) == ("server.tld", "") - - url = urls.url_parse(f"file://///{localhost}") - assert url.get_file_location(pathformat) == (None, "") - - url = urls.url_parse(f"file://///{localhost}/file.txt") - assert url.get_file_location(pathformat) == (None, r"file.txt") - - -def test_replace(): - url = urls.url_parse("http://de.wikipedia.org/wiki/Troll") - assert url.replace(query="foo=bar") == urls.url_parse( - "http://de.wikipedia.org/wiki/Troll?foo=bar" - ) - assert url.replace(scheme="https") == 
urls.url_parse( - "https://de.wikipedia.org/wiki/Troll" - ) - - -def test_quoting(): - assert urls.url_quote("\xf6\xe4\xfc") == "%C3%B6%C3%A4%C3%BC" - assert urls.url_unquote(urls.url_quote('#%="\xf6')) == '#%="\xf6' - assert urls.url_quote_plus("foo bar") == "foo+bar" - assert urls.url_unquote_plus("foo+bar") == "foo bar" - assert urls.url_quote_plus("foo+bar") == "foo%2Bbar" - assert urls.url_unquote_plus("foo%2Bbar") == "foo+bar" - assert urls.url_encode({b"a": None, b"b": b"foo bar"}) == "b=foo+bar" - assert urls.url_encode({"a": None, "b": "foo bar"}) == "b=foo+bar" - assert ( - urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffsklärung)") - == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - ) - assert urls.url_quote_plus(42) == "42" - assert urls.url_quote(b"\xff") == "%FF" - - -def test_bytes_unquoting(): - assert ( - urls.url_unquote(urls.url_quote('#%="\xf6', charset="latin1"), charset=None) - == b'#%="\xf6' - ) - - -def test_url_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"foo=42;bar=23;uni=H%C3%A4nsel", separator=b";") - assert x["foo"] == "42" - assert x["bar"] == "23" - assert x["uni"] == "Hänsel" - - x = urls.url_decode(b"%C3%9Ch=H%C3%A4nsel") - assert x["Üh"] == "Hänsel" - - -def test_url_bytes_decoding(): - x = urls.url_decode(b"foo=42&bar=23&uni=H%C3%A4nsel", charset=None) - assert x[b"foo"] == b"42" - assert x[b"bar"] == b"23" - assert x[b"uni"] == "Hänsel".encode() - - -def test_stream_decoding_string_fails(): - pytest.raises(TypeError, urls.url_decode_stream, "testing") - - -def test_url_encoding(): - assert urls.url_encode({"foo": "bar 45"}) == "foo=bar+45" - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - assert urls.url_encode(d, sort=True) == "bar=23&blah=H%C3%A4nsel&foo=1" - assert ( - urls.url_encode(d, sort=True, separator=";") == "bar=23;blah=H%C3%A4nsel;foo=1" - ) - - -def 
test_sorted_url_encode(): - assert ( - urls.url_encode( - {"a": 42, "b": 23, 1: 1, 2: 2}, sort=True, key=lambda i: str(i[0]) - ) - == "1=1&2=2&a=42&b=23" - ) - assert ( - urls.url_encode( - {"A": 1, "a": 2, "B": 3, "b": 4}, - sort=True, - key=lambda x: x[0].lower() + x[0], - ) - == "A=1&a=2&B=3&b=4" - ) - - -def test_streamed_url_encoding(): - out = io.StringIO() - urls.url_encode_stream({"foo": "bar 45"}, out) - assert out.getvalue() == "foo=bar+45" - - d = {"foo": 1, "bar": 23, "blah": "Hänsel"} - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True) - assert out.getvalue() == "bar=23&blah=H%C3%A4nsel&foo=1" - out = io.StringIO() - urls.url_encode_stream(d, out, sort=True, separator=";") - assert out.getvalue() == "bar=23;blah=H%C3%A4nsel;foo=1" - - gen = urls.url_encode_stream(d, sort=True) - assert next(gen) == "bar=23" - assert next(gen) == "blah=H%C3%A4nsel" - assert next(gen) == "foo=1" - pytest.raises(StopIteration, lambda: next(gen)) - - -def test_url_fixing(): - x = urls.url_fix("http://de.wikipedia.org/wiki/Elf (Begriffskl\xe4rung)") - assert x == "http://de.wikipedia.org/wiki/Elf%20(Begriffskl%C3%A4rung)" - - x = urls.url_fix("http://just.a.test/$-_.+!*'(),") - assert x == "http://just.a.test/$-_.+!*'()," - - x = urls.url_fix("http://höhöhö.at/höhöhö/hähähä") - assert x == r"http://xn--hhh-snabb.at/h%C3%B6h%C3%B6h%C3%B6/h%C3%A4h%C3%A4h%C3%A4" - - -def test_url_fixing_filepaths(): - x = urls.url_fix(r"file://C:\Users\Administrator\My Documents\ÑÈáÇíí") - assert x == ( - r"file:///C%3A/Users/Administrator/My%20Documents/" - r"%C3%91%C3%88%C3%A1%C3%87%C3%AD%C3%AD" - ) - - a = urls.url_fix(r"file:/C:/") - b = urls.url_fix(r"file://C:/") - c = urls.url_fix(r"file:///C:/") - assert a == b == c == r"file:///C%3A/" - - x = urls.url_fix(r"file://host/sub/path") - assert x == r"file://host/sub/path" - - x = urls.url_fix(r"file:///") - assert x == r"file:///" - - -def test_url_fixing_qs(): - x = urls.url_fix(b"http://example.com/?foo=%2f%2f") - assert x 
== "http://example.com/?foo=%2f%2f" - - x = urls.url_fix( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) - assert x == ( - "http://acronyms.thefreedictionary.com/" - "Algebraic+Methods+of+Solving+the+Schr%C3%B6dinger+Equation" - ) def test_iri_support(): assert urls.uri_to_iri("http://xn--n3h.net/") == "http://\u2603.net/" - assert ( - urls.uri_to_iri(b"http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th") - == "http://\xfcser:p\xe4ssword@\u2603.net/p\xe5th" - ) assert urls.iri_to_uri("http://☃.net/") == "http://xn--n3h.net/" assert ( urls.iri_to_uri("http://üser:pässword@☃.net/påth") == "http://%C3%BCser:p%C3%A4ssword@xn--n3h.net/p%C3%A5th" ) - assert ( urls.uri_to_iri("http://test.com/%3Fmeh?foo=%26%2F") - == "http://test.com/%3Fmeh?foo=%26%2F" + == "http://test.com/%3Fmeh?foo=%26/" ) - - # this should work as well, might break on 2.4 because of a broken - # idna codec - assert urls.uri_to_iri(b"/foo") == "/foo" assert urls.iri_to_uri("/foo") == "/foo" - assert ( urls.iri_to_uri("http://föö.com:8080/bam/baz") == "http://xn--f-1gaa.com:8080/bam/baz" ) -def test_iri_safe_conversion(): - assert urls.iri_to_uri("magnet:?foo=bar") == "magnet:?foo=bar" - assert urls.iri_to_uri("itms-service://?foo=bar") == "itms-service:?foo=bar" - assert ( - urls.iri_to_uri("itms-service://?foo=bar", safe_conversion=True) - == "itms-service://?foo=bar" - ) - - def test_iri_safe_quoting(): uri = "http://xn--f-1gaa.com/%2F%25?q=%C3%B6&x=%3D%25#%25" iri = "http://föö.com/%2F%25?q=ö&x=%3D%25#%25" @@ -242,83 +28,11 @@ def test_iri_safe_quoting(): assert urls.iri_to_uri(urls.uri_to_iri(uri)) == uri -def test_ordered_multidict_encoding(): - d = OrderedMultiDict() - d.add("foo", 1) - d.add("foo", 2) - d.add("foo", 3) - d.add("bar", 0) - d.add("foo", 4) - assert urls.url_encode(d) == "foo=1&foo=2&foo=3&bar=0&foo=4" - - -def test_multidict_encoding(): - d = OrderedMultiDict() - d.add("2013-10-10T23:26:05.657975+0000", 
"2013-10-10T23:26:05.657975+0000") - assert ( - urls.url_encode(d) - == "2013-10-10T23%3A26%3A05.657975%2B0000=2013-10-10T23%3A26%3A05.657975%2B0000" - ) - - -def test_url_unquote_plus_unicode(): - # was broken in 0.6 - assert urls.url_unquote_plus("\x6d") == "\x6d" - - def test_quoting_of_local_urls(): rv = urls.iri_to_uri("/foo\x8f") assert rv == "/foo%C2%8F" -def test_url_attributes(): - rv = urls.url_parse("http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == "http" - assert rv.auth == "foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == "foo%3a" - assert rv.raw_password == "bar%3a" - assert rv.host == "::1" - assert rv.port == 80 - assert rv.path == "/123" - assert rv.query == "x=y" - assert rv.fragment == "frag" - - rv = urls.url_parse("http://\N{SNOWMAN}.com/") - assert rv.host == "\N{SNOWMAN}.com" - assert rv.ascii_host == "xn--n3h.com" - - -def test_url_attributes_bytes(): - rv = urls.url_parse(b"http://foo%3a:bar%3a@[::1]:80/123?x=y#frag") - assert rv.scheme == b"http" - assert rv.auth == b"foo%3a:bar%3a" - assert rv.username == "foo:" - assert rv.password == "bar:" - assert rv.raw_username == b"foo%3a" - assert rv.raw_password == b"bar%3a" - assert rv.host == b"::1" - assert rv.port == 80 - assert rv.path == b"/123" - assert rv.query == b"x=y" - assert rv.fragment == b"frag" - - -def test_url_joining(): - assert urls.url_join("/foo", "/bar") == "/bar" - assert urls.url_join("http://example.com/foo", "/bar") == "http://example.com/bar" - assert urls.url_join("file:///tmp/", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "test.html") == "file:///tmp/test.html" - assert urls.url_join("file:///tmp/x", "../../../x.html") == "file:///x.html" - - -def test_partial_unencoded_decode(): - ref = "foo=정상처리".encode("euc-kr") - x = urls.url_decode(ref, charset="euc-kr") - assert x["foo"] == "정상처리" - - def test_iri_to_uri_idempotence_ascii_only(): uri = 
"http://www.idempoten.ce" uri = urls.iri_to_uri(uri) @@ -355,31 +69,32 @@ def test_uri_to_iri_to_uri(): assert urls.iri_to_uri(iri) == uri -def test_uri_iri_normalization(): - uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93" - iri = "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713" - - tests = [ +@pytest.mark.parametrize( + "value", + [ "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713", "http://xn--f-rgao.com/\u2610/fred?utf8=\N{CHECK MARK}", - b"http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", + "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93", "http://föñ.com/\u2610/fred?utf8=%E2%9C%93", - b"http://xn--f-rgao.com/\xe2\x98\x90/fred?utf8=\xe2\x9c\x93", - ] - - for test in tests: - assert urls.uri_to_iri(test) == iri - assert urls.iri_to_uri(test) == uri - assert urls.uri_to_iri(urls.iri_to_uri(test)) == iri - assert urls.iri_to_uri(urls.uri_to_iri(test)) == uri - assert urls.uri_to_iri(urls.uri_to_iri(test)) == iri - assert urls.iri_to_uri(urls.iri_to_uri(test)) == uri + ], +) +def test_uri_iri_normalization(value): + uri = "http://xn--f-rgao.com/%E2%98%90/fred?utf8=%E2%9C%93" + iri = "http://föñ.com/\N{BALLOT BOX}/fred?utf8=\u2713" + assert urls.uri_to_iri(value) == iri + assert urls.iri_to_uri(value) == uri + assert urls.uri_to_iri(urls.iri_to_uri(value)) == iri + assert urls.iri_to_uri(urls.uri_to_iri(value)) == uri + assert urls.uri_to_iri(urls.uri_to_iri(value)) == iri + assert urls.iri_to_uri(urls.iri_to_uri(value)) == uri def test_uri_to_iri_dont_unquote_space(): assert urls.uri_to_iri("abc%20def") == "abc%20def" -def test_iri_to_uri_dont_quote_reserved(): - assert urls.iri_to_uri("/path[bracket]?(paren)") == "/path[bracket]?(paren)" +def test_iri_to_uri_dont_quote_valid_code_points(): + # [] are not valid URL code points according to WhatWG URL Standard + # https://url.spec.whatwg.org/#url-code-points + assert urls.iri_to_uri("/path[bracket]?(paren)") == "/path%5Bbracket%5D?(paren)" 
diff --git a/tests/test_utils.py b/tests/test_utils.py index ed8d8d0..b7f1bcb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect from datetime import datetime @@ -9,48 +11,32 @@ from werkzeug.datastructures import Headers from werkzeug.http import http_date from werkzeug.http import parse_date from werkzeug.test import Client +from werkzeug.test import EnvironBuilder from werkzeug.wrappers import Response -def test_redirect(): - resp = utils.redirect("/füübär") - assert resp.headers["Location"] == "/f%C3%BC%C3%BCb%C3%A4r" - assert resp.status_code == 302 - assert resp.get_data() == ( - b"\n" - b"\n" - b"Redirecting...\n" - b"

Redirecting...

\n" - b"

You should be redirected automatically to the target URL: " - b'/f\xc3\xbc\xc3\xbcb\xc3\xa4r. ' - b"If not, click the link.\n" - ) +@pytest.mark.parametrize( + ("url", "code", "expect"), + [ + ("http://example.com", None, "http://example.com"), + ("/füübär", 305, "/f%C3%BC%C3%BCb%C3%A4r"), + ("http://☃.example.com/", 307, "http://xn--n3h.example.com/"), + ("itms-services://?url=abc", None, "itms-services://?url=abc"), + ], +) +def test_redirect(url: str, code: int | None, expect: str) -> None: + environ = EnvironBuilder().get_environ() - resp = utils.redirect("http://☃.net/", 307) - assert resp.headers["Location"] == "http://xn--n3h.net/" - assert resp.status_code == 307 - assert resp.get_data() == ( - b"\n" - b"\n" - b"Redirecting...\n" - b"

Redirecting...

\n" - b"

You should be redirected automatically to the target URL: " - b'http://\xe2\x98\x83.net/. ' - b"If not, click the link.\n" - ) + if code is None: + resp = utils.redirect(url) + assert resp.status_code == 302 + else: + resp = utils.redirect(url, code) + assert resp.status_code == code - resp = utils.redirect("http://example.com/", 305) - assert resp.headers["Location"] == "http://example.com/" - assert resp.status_code == 305 - assert resp.get_data() == ( - b"\n" - b"\n" - b"Redirecting...\n" - b"

Redirecting...

\n" - b"

You should be redirected automatically to the target URL: " - b'http://example.com/. ' - b"If not, click the link.\n" - ) + assert resp.headers["Location"] == url + assert resp.get_wsgi_headers(environ)["Location"] == expect + assert resp.get_data(as_text=True).count(url) == 2 def test_redirect_xss(): diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index b769a38..8a91aef 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -20,9 +20,11 @@ from werkzeug.datastructures import ImmutableOrderedMultiDict from werkzeug.datastructures import LanguageAccept from werkzeug.datastructures import MIMEAccept from werkzeug.datastructures import MultiDict +from werkzeug.datastructures import WWWAuthenticate from werkzeug.exceptions import BadRequest from werkzeug.exceptions import RequestedRangeNotSatisfiable from werkzeug.exceptions import SecurityError +from werkzeug.exceptions import UnsupportedMediaType from werkzeug.http import COEP from werkzeug.http import COOP from werkzeug.http import generate_etag @@ -136,11 +138,12 @@ def test_url_request_descriptors(): def test_url_request_descriptors_query_quoting(): - next = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash" - req = wrappers.Request.from_values(f"/bar?next={next}", "http://example.com/") + quoted = "http%3A%2F%2Fwww.example.com%2F%3Fnext%3D%2Fbaz%23my%3Dhash" + unquoted = "http://www.example.com/?next%3D/baz%23my%3Dhash" + req = wrappers.Request.from_values(f"/bar?next={quoted}", "http://example.com/") assert req.path == "/bar" - assert req.full_path == f"/bar?next={next}" - assert req.url == f"http://example.com/bar?next={next}" + assert req.full_path == f"/bar?next={quoted}" + assert req.url == f"http://example.com/bar?next={unquoted}" def test_url_request_descriptors_hosts(): @@ -349,13 +352,6 @@ def test_response_init_status_empty_string(): assert "Empty status argument" in str(info.value) -def test_response_init_status_tuple(): - with pytest.raises(TypeError) as info: - 
wrappers.Response(None, tuple()) - - assert "Invalid status argument" in str(info.value) - - def test_type_forcing(): def wsgi_application(environ, start_response): start_response("200 OK", [("Content-Type", "text/html")]) @@ -686,27 +682,26 @@ def test_etag_response_freezing(): def test_authenticate(): resp = wrappers.Response() - resp.www_authenticate.type = "basic" resp.www_authenticate.realm = "Testing" - assert resp.headers["WWW-Authenticate"] == 'Basic realm="Testing"' - resp.www_authenticate.realm = None - resp.www_authenticate.type = None + assert resp.headers["WWW-Authenticate"] == "Basic realm=Testing" + del resp.www_authenticate assert "WWW-Authenticate" not in resp.headers def test_authenticate_quoted_qop(): # Example taken from https://github.com/pallets/werkzeug/issues/633 resp = wrappers.Response() - resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth", "auth-int")) + resp.www_authenticate = WWWAuthenticate( + "digest", {"realm": "REALM", "nonce": "NONCE", "qop": "auth, auth-int"} + ) - actual = set(f"{resp.headers['WWW-Authenticate']},".split()) - expected = set('Digest nonce="NONCE", realm="REALM", qop="auth, auth-int",'.split()) + actual = resp.headers["WWW-Authenticate"] + expected = 'Digest realm="REALM", nonce="NONCE", qop="auth, auth-int"' assert actual == expected - resp.www_authenticate.set_digest("REALM", "NONCE", qop=("auth",)) - - actual = set(f"{resp.headers['WWW-Authenticate']},".split()) - expected = set('Digest nonce="NONCE", realm="REALM", qop="auth",'.split()) + resp.www_authenticate.parameters["qop"] = "auth" + actual = resp.headers["WWW-Authenticate"] + expected = 'Digest realm="REALM", nonce="NONCE", qop="auth"' assert actual == expected @@ -875,12 +870,6 @@ def test_file_closing_with(): assert foo.closed is True -def test_url_charset_reflection(): - req = wrappers.Request.from_values() - req.charset = "utf-7" - assert req.url_charset == "utf-7" - - def test_response_streamed(): r = wrappers.Response() assert not 
r.is_streamed @@ -1206,14 +1195,6 @@ def test_malformed_204_response_has_no_content_length(): assert b"".join(app_iter) == b"" # ensure data will not be sent -def test_modified_url_encoding(): - class ModifiedRequest(wrappers.Request): - url_charset = "euc-kr" - - req = ModifiedRequest.from_values(query_string={"foo": "정상처리"}, charset="euc-kr") - assert req.args["foo"] == "정상처리" - - def test_request_method_case_sensitivity(): req = wrappers.Request( {"REQUEST_METHOD": "get", "SERVER_NAME": "eggs", "SERVER_PORT": "80"} @@ -1350,7 +1331,7 @@ class TestJSON: value = [1, 2, 3] request = wrappers.Request.from_values(json=value, content_type="text/plain") - with pytest.raises(BadRequest): + with pytest.raises(UnsupportedMediaType): request.get_json() assert request.get_json(silent=True) is None diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py index b0f71bc..7f4d2e9 100644 --- a/tests/test_wsgi.py +++ b/tests/test_wsgi.py @@ -1,6 +1,9 @@ +from __future__ import annotations + import io import json import os +import typing as t import pytest @@ -84,7 +87,6 @@ def test_responder(): def test_path_info_and_script_name_fetching(): env = create_environ("/\N{SNOWMAN}", "http://example.com/\N{COMET}/") assert wsgi.get_path_info(env) == "/\N{SNOWMAN}" - assert wsgi.get_path_info(env, charset=None) == "/\N{SNOWMAN}".encode() def test_limited_stream(): @@ -117,11 +119,10 @@ def test_limited_stream(): stream = wsgi.LimitedStream(io_, 9) assert stream.readlines() == [b"123456\n", b"ab"] - io_ = io.BytesIO(b"123456\nabcdefg") + io_ = io.BytesIO(b"123\n456\nabcdefg") stream = wsgi.LimitedStream(io_, 9) - assert stream.readlines(2) == [b"12"] - assert stream.readlines(2) == [b"34"] - assert stream.readlines() == [b"56\n", b"ab"] + assert stream.readlines(2) == [b"123\n"] + assert stream.readlines() == [b"456\n", b"a"] io_ = io.BytesIO(b"123456\nabcdefg") stream = wsgi.LimitedStream(io_, 9) @@ -146,13 +147,8 @@ def test_limited_stream(): stream = wsgi.LimitedStream(io_, 0) assert 
stream.read(-1) == b"" - io_ = io.StringIO("123456") - stream = wsgi.LimitedStream(io_, 0) - assert stream.read(-1) == "" - - io_ = io.StringIO("123\n456\n") - stream = wsgi.LimitedStream(io_, 8) - assert list(stream) == ["123\n", "456\n"] + stream = wsgi.LimitedStream(io.BytesIO(b"123\n456\n"), 8) + assert list(stream) == [b"123\n", b"456\n"] def test_limited_stream_json_load(): @@ -165,21 +161,63 @@ def test_limited_stream_json_load(): def test_limited_stream_disconnection(): - io_ = io.BytesIO(b"A bit of content") - - # disconnect detection on out of bytes - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream returns zero bytes + stream = wsgi.LimitedStream(io.BytesIO(), 255) with pytest.raises(ClientDisconnected): stream.read() - # disconnect detection because file close - io_ = io.BytesIO(b"x" * 255) - io_.close() - stream = wsgi.LimitedStream(io_, 255) + # disconnect because stream is closed + data = io.BytesIO(b"x" * 255) + data.close() + stream = wsgi.LimitedStream(data, 255) + with pytest.raises(ClientDisconnected): stream.read() +def test_limited_stream_read_with_raw_io(): + class OneByteStream(t.BinaryIO): + def __init__(self, buf: bytes) -> None: + self.buf = buf + self.pos = 0 + + def read(self, size: int | None = None) -> bytes: + """Return one byte at a time regardless of requested size.""" + + if size is None or size == -1: + raise ValueError("expected read to be called with specific limit") + + if size == 0 or len(self.buf) < self.pos: + return b"" + + b = self.buf[self.pos : self.pos + 1] + self.pos += 1 + return b + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 4) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + + # The stream has fewer bytes (3) than the limit (4), therefore the read returns 0 + # bytes before the limit is reached. 
+ with pytest.raises(ClientDisconnected): + stream.read(5) + + stream = wsgi.LimitedStream(OneByteStream(b"foo123"), 3) + assert stream.read(5) == b"f" + assert stream.read(5) == b"o" + assert stream.read(5) == b"o" + # The limit was reached, therefore the wrapper is exhausted, not disconnected. + assert stream.read(5) == b"" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 3) + assert stream.read() == b"foo" + + stream = wsgi.LimitedStream(OneByteStream(b"foo"), 2) + assert stream.read() == b"fo" + + def test_get_host_fallback(): assert ( wsgi.get_host( @@ -218,123 +256,6 @@ def test_get_current_url_invalid_utf8(): assert rv == "http://localhost/?foo=bar&baz=blah&meh=%CF" -def test_multi_part_line_breaks(): - data = "abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - - data = "abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - "abc\r\n", - "This line is broken by the buffer length.\r\n", - "Foo bar baz", - ] - - -def test_multi_part_line_breaks_bytes(): - data = b"abcdef\r\nghijkl\r\nmnopqrstuvwxyz\r\nABCDEFGHIJK" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=16)) - assert lines == [ - b"abcdef\r\n", - b"ghijkl\r\n", - b"mnopqrstuvwxyz\r\n", - b"ABCDEFGHIJK", - ] - - data = b"abc\r\nThis line is broken by the buffer length.\r\nFoo bar baz" - test_stream = io.BytesIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=24)) - assert lines == [ - b"abc\r\n", - b"This line is broken by the buffer length.\r\n", - b"Foo bar baz", - ] - - -def test_multi_part_line_breaks_problematic(): - data = "abc\rdef\r\nghi" - for _ in 
range(1, 10): - test_stream = io.StringIO(data) - lines = list(wsgi.make_line_iter(test_stream, limit=len(data), buffer_size=4)) - assert lines == ["abc\r", "def\r\n", "ghi"] - - -def test_iter_functions_support_iterators(): - data = ["abcdef\r\nghi", "jkl\r\nmnopqrstuvwxyz\r", "\nABCDEFGHIJK"] - lines = list(wsgi.make_line_iter(data)) - assert lines == ["abcdef\r\n", "ghijkl\r\n", "mnopqrstuvwxyz\r\n", "ABCDEFGHIJK"] - - -def test_make_chunk_iter(): - data = ["abcdefXghi", "jklXmnopqrstuvwxyzX", "ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, "X")) - assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"] - - data = "abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.StringIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) - assert rv == ["abcdef", "ghijkl", "mnopqrstuvwxyz", "ABCDEFGHIJK"] - - -def test_make_chunk_iter_bytes(): - data = [b"abcdefXghi", b"jklXmnopqrstuvwxyzX", b"ABCDEFGHIJK"] - rv = list(wsgi.make_chunk_iter(data, "X")) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list(wsgi.make_chunk_iter(test_stream, "X", limit=len(data), buffer_size=4)) - assert rv == [b"abcdef", b"ghijkl", b"mnopqrstuvwxyz", b"ABCDEFGHIJK"] - - data = b"abcdefXghijklXmnopqrstuvwxyzXABCDEFGHIJK" - test_stream = io.BytesIO(data) - rv = list( - wsgi.make_chunk_iter( - test_stream, "X", limit=len(data), buffer_size=4, cap_at_buffer=True - ) - ) - assert rv == [ - b"abcd", - b"ef", - b"ghij", - b"kl", - b"mnop", - b"qrst", - b"uvwx", - b"yz", - b"ABCD", - b"EFGH", - b"IJK", - ] - - -def test_lines_longer_buffer_size(): - data = "1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter(io.StringIO(data), limit=len(data), buffer_size=bufsize) - ) - assert lines == ["1234567890\n", "1234567890\n"] - - -def test_lines_longer_buffer_size_cap(): - data = 
"1234567890\n1234567890\n" - for bufsize in range(1, 15): - lines = list( - wsgi.make_line_iter( - io.StringIO(data), - limit=len(data), - buffer_size=bufsize, - cap_at_buffer=True, - ) - ) - assert len(lines[0]) == bufsize or lines[0].endswith("\n") - - def test_range_wrapper(): response = Response(b"Hello World") range_wrapper = _RangeWrapper(response.response, 6, 4) diff --git a/tox.ini b/tox.ini index 056ca0d..eca667f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,19 +1,22 @@ [tox] envlist = - py3{11,10,9,8,7},pypy3{8,7} + py3{12,11,10,9,8} + pypy310 style typing docs skip_missing_interpreters = true [testenv] +package = wheel +wheel_build_env = .pkg deps = -r requirements/tests.txt commands = pytest -v --tb=short --basetemp={envtmpdir} {posargs} [testenv:style] deps = pre-commit skip_install = true -commands = pre-commit run --all-files --show-diff-on-failure +commands = pre-commit run --all-files [testenv:typing] deps = -r requirements/typing.txt