diff --git a/Lib/test/test_types.py b/Lib/test/test_types.py index 6ed74714dc57..99b011e37866 100644 --- a/Lib/test/test_types.py +++ b/Lib/test/test_types.py @@ -772,6 +772,36 @@ def test_union_parameter_chaining(self): self.assertEqual((list[T] | list[S])[int, T], list[int] | list[T]) self.assertEqual((list[T] | list[S])[int, int], list[int]) + def test_union_parameter_substitution(self): + def eq(actual, expected): + self.assertEqual(actual, expected) + self.assertIs(type(actual), type(expected)) + + T = typing.TypeVar('T') + S = typing.TypeVar('S') + NT = typing.NewType('NT', str) + x = int | T | bytes + + eq(x[str], int | str | bytes) + eq(x[list[int]], int | list[int] | bytes) + eq(x[typing.List], int | typing.List | bytes) + eq(x[typing.List[int]], int | typing.List[int] | bytes) + eq(x[typing.Hashable], int | typing.Hashable | bytes) + eq(x[collections.abc.Hashable], + int | collections.abc.Hashable | bytes) + eq(x[typing.Callable[[int], str]], + int | typing.Callable[[int], str] | bytes) + eq(x[collections.abc.Callable[[int], str]], + int | collections.abc.Callable[[int], str] | bytes) + eq(x[typing.Tuple[int, str]], int | typing.Tuple[int, str] | bytes) + eq(x[typing.Literal['none']], int | typing.Literal['none'] | bytes) + eq(x[str | list], int | str | list | bytes) + eq(x[typing.Union[str, list]], typing.Union[int, str, list, bytes]) + eq(x[str | int], int | str | bytes) + eq(x[typing.Union[str, int]], typing.Union[int, str, bytes]) + eq(x[NT], int | NT | bytes) + eq(x[S], int | S | bytes) + def test_union_parameter_substitution_errors(self): T = typing.TypeVar("T") x = int | T diff --git a/Misc/NEWS.d/next/Core and Builtins/2021-07-19-20-49-06.bpo-44653.WcqGyI.rst b/Misc/NEWS.d/next/Core and Builtins/2021-07-19-20-49-06.bpo-44653.WcqGyI.rst new file mode 100644 index 000000000000..8254d9bbad4a --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2021-07-19-20-49-06.bpo-44653.WcqGyI.rst @@ -0,0 +1 @@ +Support :mod:`typing` types in parameter substitution in the union type. diff --git a/Objects/unionobject.c b/Objects/unionobject.c index c0c9a24bcc20..659346aac821 100644 --- a/Objects/unionobject.c +++ b/Objects/unionobject.c @@ -446,23 +446,22 @@ union_getitem(PyObject *self, PyObject *item) return NULL; } - // Check arguments are unionable. + PyObject *res; Py_ssize_t nargs = PyTuple_GET_SIZE(newargs); - for (Py_ssize_t iarg = 0; iarg < nargs; iarg++) { - PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); - int is_arg_unionable = is_unionable(arg); - if (is_arg_unionable <= 0) { - Py_DECREF(newargs); - if (is_arg_unionable == 0) { - PyErr_Format(PyExc_TypeError, - "Each union argument must be a type, got %.100R", arg); + if (nargs == 0) { + res = make_union(newargs); + } + else { + res = PyTuple_GET_ITEM(newargs, 0); + Py_INCREF(res); + for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) { + PyObject *arg = PyTuple_GET_ITEM(newargs, iarg); + Py_SETREF(res, PyNumber_Or(res, arg)); + if (res == NULL) { + break; } - return NULL; } } - - PyObject *res = make_union(newargs); - Py_DECREF(newargs); return res; }