mirror of https://github.com/python/cpython.git
584 lines
15 KiB
C
584 lines
15 KiB
C
// typing.Union -- used to represent e.g. Union[int, str], int | str
|
|
#include "Python.h"
|
|
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
|
|
#include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr
|
|
#include "pycore_unicodeobject.h" // _PyUnicode_EqualToASCIIString
|
|
#include "pycore_unionobject.h"
|
|
|
|
|
|
// Instance layout of typing.Union objects (e.g. the result of `int | str`).
// Instances are created only by make_union(), which fills every field.
typedef struct {
    PyObject_HEAD
    PyObject *args; // all args (tuple): distinct members in insertion order
    PyObject *hashable_args; // frozenset or NULL: members that could be hashed
    PyObject *unhashable_args; // tuple or NULL: members whose hash raised
    PyObject *parameters; // __parameters__ cache; NULL until union_init_parameters()
    PyObject *weakreflist; // weak reference list (see tp_weaklistoffset)
} unionobject;
|
|
|
|
static void
|
|
unionobject_dealloc(PyObject *self)
|
|
{
|
|
unionobject *alias = (unionobject *)self;
|
|
|
|
_PyObject_GC_UNTRACK(self);
|
|
if (alias->weakreflist != NULL) {
|
|
PyObject_ClearWeakRefs((PyObject *)alias);
|
|
}
|
|
|
|
Py_XDECREF(alias->args);
|
|
Py_XDECREF(alias->hashable_args);
|
|
Py_XDECREF(alias->unhashable_args);
|
|
Py_XDECREF(alias->parameters);
|
|
Py_TYPE(self)->tp_free(self);
|
|
}
|
|
|
|
static int
|
|
union_traverse(PyObject *self, visitproc visit, void *arg)
|
|
{
|
|
unionobject *alias = (unionobject *)self;
|
|
Py_VISIT(alias->args);
|
|
Py_VISIT(alias->hashable_args);
|
|
Py_VISIT(alias->unhashable_args);
|
|
Py_VISIT(alias->parameters);
|
|
return 0;
|
|
}
|
|
|
|
static Py_hash_t
|
|
union_hash(PyObject *self)
|
|
{
|
|
unionobject *alias = (unionobject *)self;
|
|
// If there are any unhashable args, treat this union as unhashable.
|
|
// Otherwise, two unions might compare equal but have different hashes.
|
|
if (alias->unhashable_args) {
|
|
// Attempt to get an error from one of the values.
|
|
assert(PyTuple_CheckExact(alias->unhashable_args));
|
|
Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
|
|
for (Py_ssize_t i = 0; i < n; i++) {
|
|
PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
|
|
Py_hash_t hash = PyObject_Hash(arg);
|
|
if (hash == -1) {
|
|
return -1;
|
|
}
|
|
}
|
|
// The unhashable values somehow became hashable again. Still raise
|
|
// an error.
|
|
PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n);
|
|
return -1;
|
|
}
|
|
return PyObject_Hash(alias->hashable_args);
|
|
}
|
|
|
|
static int
|
|
unions_equal(unionobject *a, unionobject *b)
|
|
{
|
|
int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
|
|
if (result == -1) {
|
|
return -1;
|
|
}
|
|
if (result == 0) {
|
|
return 0;
|
|
}
|
|
if (a->unhashable_args && b->unhashable_args) {
|
|
Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
|
|
if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
|
|
return 0;
|
|
}
|
|
for (Py_ssize_t i = 0; i < n; i++) {
|
|
PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
|
|
int result = PySequence_Contains(b->unhashable_args, arg_a);
|
|
if (result == -1) {
|
|
return -1;
|
|
}
|
|
if (!result) {
|
|
return 0;
|
|
}
|
|
}
|
|
for (Py_ssize_t i = 0; i < n; i++) {
|
|
PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
|
|
int result = PySequence_Contains(a->unhashable_args, arg_b);
|
|
if (result == -1) {
|
|
return -1;
|
|
}
|
|
if (!result) {
|
|
return 0;
|
|
}
|
|
}
|
|
}
|
|
else if (a->unhashable_args || b->unhashable_args) {
|
|
return 0;
|
|
}
|
|
return 1;
|
|
}
|
|
|
|
// tp_richcompare: only == and != between two unions are supported; any other
// operator or operand type defers via NotImplemented.
static PyObject *
union_richcompare(PyObject *a, PyObject *b, int op)
{
    bool is_equality_op = (op == Py_EQ || op == Py_NE);
    if (!is_equality_op || !_PyUnion_Check(b)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    int eq = unions_equal((unionobject *)a, (unionobject *)b);
    if (eq < 0) {
        return NULL;
    }
    return PyBool_FromLong(op == Py_EQ ? eq : !eq);
}
|
|
|
|
// Accumulator used while collecting the members of a new union. Members are
// deduplicated as they are added (see unionbuilder_add_single_unchecked());
// make_union() converts the collected state into a unionobject and releases
// the builder via unionbuilder_finalize().
typedef struct {
    PyObject *args; // list: every distinct member, in insertion order
    PyObject *hashable_args; // set: members that hashed successfully
    PyObject *unhashable_args; // list or NULL: created on the first unhashable member
    bool is_checked; // whether to call type_check()
} unionbuilder;

// Forward declarations: these are defined further down in the file but are
// used by the builder helpers below.
static bool unionbuilder_add_tuple(unionbuilder *, PyObject *);
static PyObject *make_union(unionbuilder *);
static PyObject *type_check(PyObject *, const char *);
|
|
|
|
static bool
|
|
unionbuilder_init(unionbuilder *ub, bool is_checked)
|
|
{
|
|
ub->args = PyList_New(0);
|
|
if (ub->args == NULL) {
|
|
return false;
|
|
}
|
|
ub->hashable_args = PySet_New(NULL);
|
|
if (ub->hashable_args == NULL) {
|
|
Py_DECREF(ub->args);
|
|
return false;
|
|
}
|
|
ub->unhashable_args = NULL;
|
|
ub->is_checked = is_checked;
|
|
return true;
|
|
}
|
|
|
|
static void
|
|
unionbuilder_finalize(unionbuilder *ub)
|
|
{
|
|
Py_DECREF(ub->args);
|
|
Py_DECREF(ub->hashable_args);
|
|
Py_XDECREF(ub->unhashable_args);
|
|
}
|
|
|
|
static bool
|
|
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
|
|
{
|
|
Py_hash_t hash = PyObject_Hash(arg);
|
|
if (hash == -1) {
|
|
PyErr_Clear();
|
|
if (ub->unhashable_args == NULL) {
|
|
ub->unhashable_args = PyList_New(0);
|
|
if (ub->unhashable_args == NULL) {
|
|
return false;
|
|
}
|
|
}
|
|
else {
|
|
int contains = PySequence_Contains(ub->unhashable_args, arg);
|
|
if (contains < 0) {
|
|
return false;
|
|
}
|
|
if (contains == 1) {
|
|
return true;
|
|
}
|
|
}
|
|
if (PyList_Append(ub->unhashable_args, arg) < 0) {
|
|
return false;
|
|
}
|
|
}
|
|
else {
|
|
int contains = PySet_Contains(ub->hashable_args, arg);
|
|
if (contains < 0) {
|
|
return false;
|
|
}
|
|
if (contains == 1) {
|
|
return true;
|
|
}
|
|
if (PySet_Add(ub->hashable_args, arg) < 0) {
|
|
return false;
|
|
}
|
|
}
|
|
return PyList_Append(ub->args, arg) == 0;
|
|
}
|
|
|
|
static bool
|
|
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
|
|
{
|
|
if (Py_IsNone(arg)) {
|
|
arg = (PyObject *)&_PyNone_Type; // immortal, so no refcounting needed
|
|
}
|
|
else if (_PyUnion_Check(arg)) {
|
|
PyObject *args = ((unionobject *)arg)->args;
|
|
return unionbuilder_add_tuple(ub, args);
|
|
}
|
|
if (ub->is_checked) {
|
|
PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
|
|
if (type == NULL) {
|
|
return false;
|
|
}
|
|
bool result = unionbuilder_add_single_unchecked(ub, type);
|
|
Py_DECREF(type);
|
|
return result;
|
|
}
|
|
else {
|
|
return unionbuilder_add_single_unchecked(ub, arg);
|
|
}
|
|
}
|
|
|
|
static bool
|
|
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
|
|
{
|
|
Py_ssize_t n = PyTuple_GET_SIZE(tuple);
|
|
for (Py_ssize_t i = 0; i < n; i++) {
|
|
if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static int
|
|
is_unionable(PyObject *obj)
|
|
{
|
|
if (obj == Py_None ||
|
|
PyType_Check(obj) ||
|
|
_PyGenericAlias_Check(obj) ||
|
|
_PyUnion_Check(obj) ||
|
|
Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
|
|
return 1;
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
// nb_or implementation: `self | other` for unionable operands.
// Returns NotImplemented if either operand is not unionable.
PyObject *
_Py_union_type_or(PyObject* self, PyObject* other)
{
    if (!is_unionable(self) || !is_unionable(other)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    unionbuilder ub;
    // Both operands already passed is_unionable(), so skip type_check().
    if (!unionbuilder_init(&ub, false)) {
        return NULL;
    }

    bool added = unionbuilder_add_single(&ub, self)
                 && unionbuilder_add_single(&ub, other);
    if (!added) {
        unionbuilder_finalize(&ub);
        return NULL;
    }
    // make_union() releases the builder on every path.
    return make_union(&ub);
}
|
|
|
|
// tp_repr: the members' typing reprs joined with " | ", e.g. "int | str".
static PyObject *
union_repr(PyObject *self)
{
    PyObject *members = ((unionobject *)self)->args;
    Py_ssize_t count = PyTuple_GET_SIZE(members);

    // Size hint: shortest type name "int" (3 chars) plus the " | " separator
    // (3 chars). Fall back to `count` if the multiplication would overflow.
    Py_ssize_t size_hint = count;
    if (count <= PY_SSIZE_T_MAX / 6) {
        size_hint = count * 6;
    }

    PyUnicodeWriter *writer = PyUnicodeWriter_Create(size_hint);
    if (writer == NULL) {
        return NULL;
    }

    for (Py_ssize_t idx = 0; idx < count; idx++) {
        if (idx != 0 && PyUnicodeWriter_WriteASCII(writer, " | ", 3) < 0) {
            goto fail;
        }
        if (_Py_typing_type_repr(writer, PyTuple_GET_ITEM(members, idx)) < 0) {
            goto fail;
        }
    }
    return PyUnicodeWriter_Finish(writer);

fail:
    PyUnicodeWriter_Discard(writer);
    return NULL;
}
|
|
|
|
// Plain data members exposed on union instances.
static PyMemberDef union_members[] = {
    // __args__: read-only tuple of the union's distinct members.
    {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY},
    {0} // sentinel
};
|
|
|
|
// Populate __parameters__ if needed.
|
|
static int
|
|
union_init_parameters(unionobject *alias)
|
|
{
|
|
int result = 0;
|
|
Py_BEGIN_CRITICAL_SECTION(alias);
|
|
if (alias->parameters == NULL) {
|
|
alias->parameters = _Py_make_parameters(alias->args);
|
|
if (alias->parameters == NULL) {
|
|
result = -1;
|
|
}
|
|
}
|
|
Py_END_CRITICAL_SECTION();
|
|
return result;
|
|
}
|
|
|
|
// mp_subscript: substitute `item` for the union's type parameters and build
// a new (checked) union from the substituted arguments.
static PyObject *
union_getitem(PyObject *self, PyObject *item)
{
    unionobject *u = (unionobject *)self;
    PyObject *res = NULL;

    if (union_init_parameters(u) == 0) {
        PyObject *substituted = _Py_subs_parameters(self, u->args, u->parameters, item);
        if (substituted != NULL) {
            res = _Py_union_from_tuple(substituted);
            Py_DECREF(substituted);
        }
    }
    return res;
}
|
|
|
|
// Mapping protocol: only subscription (`u[...]`) is supported, implemented by
// union_getitem() as type-parameter substitution.
static PyMappingMethods union_as_mapping = {
    .mp_subscript = union_getitem,
};
|
|
|
|
// Getter for __parameters__: computes the cache on first use and returns a
// new reference to it.
static PyObject *
union_parameters(PyObject *self, void *Py_UNUSED(unused))
{
    unionobject *u = (unionobject *)self;
    return union_init_parameters(u) < 0 ? NULL : Py_NewRef(u->parameters);
}
|
|
|
|
// Getter shared by __name__ and __qualname__; both report "Union".
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    static const char name[] = "Union";
    return PyUnicode_FromString(name);
}
|
|
|
|
// Getter for __origin__: a new reference to the union type itself.
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *origin = (PyObject *)&_PyUnion_Type;
    return Py_NewRef(origin);
}
|
|
|
|
// Computed (read-only) attributes of union instances.
static PyGetSetDef union_properties[] = {
    // __name__ and __qualname__ share one getter.
    {"__name__", union_name, NULL,
     PyDoc_STR("Name of the type"), NULL},
    {"__qualname__", union_name, NULL,
     PyDoc_STR("Qualified name of the type"), NULL},
    // union_origin() returns the union type object (_PyUnion_Type) itself.
    {"__origin__", union_origin, NULL,
     PyDoc_STR("Always returns the type"), NULL},
    // Computed lazily by union_init_parameters() on first access.
    {"__parameters__", union_parameters, NULL,
     PyDoc_STR("Type variables in the types.UnionType."), NULL},
    {0} // sentinel
};
|
|
|
|
// Number protocol: only `|` is supported, so unions can be extended further,
// e.g. `(int | str) | bytes`.
static PyNumberMethods union_as_number = {
    .nb_or = _Py_union_type_or, // Add __or__ function
};
|
|
|
|
// NULL-terminated list of attribute names that union_getattro() looks up on
// the union type instead of on the instance.
static const char* const cls_attrs[] = {
    "__module__", // Required for compatibility with typing module
    NULL,
};
|
|
|
|
// tp_getattro: names listed in cls_attrs are looked up on the union type.
// Everything else uses generic attribute lookup on the instance.
static PyObject *
union_getattro(PyObject *self, PyObject *name)
{
    if (!PyUnicode_Check(name)) {
        return PyObject_GenericGetAttr(self, name);
    }
    for (const char *const *attr = cls_attrs; *attr != NULL; attr++) {
        if (_PyUnicode_EqualToASCIIString(name, *attr)) {
            return PyObject_GetAttr((PyObject *)Py_TYPE(self), name);
        }
    }
    return PyObject_GenericGetAttr(self, name);
}
|
|
|
|
// Return the member tuple of a union as a borrowed reference.
// `self` must be a union.
PyObject *
_Py_union_args(PyObject *self)
{
    assert(_PyUnion_Check(self));
    unionobject *u = (unionobject *)self;
    return u->args;
}
|
|
|
|
// Call `typing.<name>(*args)` via vectorcall.
// Returns the call's result, or NULL with an exception set if the import,
// the attribute lookup or the call fails.
static PyObject *
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
{
    PyObject *result = NULL;
    PyObject *typing = PyImport_ImportModule("typing");
    if (typing != NULL) {
        PyObject *func = PyObject_GetAttrString(typing, name);
        if (func != NULL) {
            result = PyObject_Vectorcall(func, args, nargs, NULL);
            Py_DECREF(func);
        }
        Py_DECREF(typing);
    }
    return result;
}
|
|
|
|
// Validate one Union[...] argument and return a new reference to the value
// to store.
//
// None becomes NoneType. Unionable objects are accepted as-is. Anything else
// is passed to typing._type_check(arg, msg), which may convert it or raise.
static PyObject *
type_check(PyObject *arg, const char *msg)
{
    if (Py_IsNone(arg)) {
        // NoneType is immortal, so don't need an INCREF
        return (PyObject *)Py_TYPE(arg);
    }
    // Fast path to avoid calling into typing.py
    if (is_unionable(arg)) {
        return Py_NewRef(arg);
    }

    PyObject *msg_obj = PyUnicode_FromString(msg);
    if (msg_obj == NULL) {
        return NULL;
    }
    PyObject *argv[] = {arg, msg_obj};
    PyObject *checked = call_typing_func_object("_type_check", argv,
                                                sizeof argv / sizeof argv[0]);
    Py_DECREF(msg_obj);
    return checked;
}
|
|
|
|
// Build a checked union from `args`.
//
// `args` is either an exact tuple of members or a single member. Each member
// is validated with type_check(). Returns a new reference, or NULL with an
// exception set. If exactly one distinct member remains, make_union()
// returns that member itself rather than a union.
PyObject *
_Py_union_from_tuple(PyObject *args)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }

    bool ok = PyTuple_CheckExact(args)
              ? unionbuilder_add_tuple(&ub, args)
              : unionbuilder_add_single(&ub, args);
    if (!ok) {
        // Release the builder's list/set(s); previously a failed type check
        // leaked them because only make_union() finalized the builder.
        unionbuilder_finalize(&ub);
        return NULL;
    }
    // make_union() finalizes the builder on every path.
    return make_union(&ub);
}
|
|
|
|
// Union.__class_getitem__: `Union[X, Y, ...]` builds a checked union from the
// subscript. The class argument itself is not needed.
static PyObject *
union_class_getitem(PyObject *Py_UNUSED(cls), PyObject *args)
{
    return _Py_union_from_tuple(args);
}
|
|
|
|
// __mro_entries__ hook: always raises TypeError("Cannot subclass <union>"),
// so a union cannot be used as a base class.
static PyObject *
union_mro_entries(PyObject *self, PyObject *Py_UNUSED(args))
{
    PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
    return NULL;
}
|
|
|
|
// Methods of union objects.
static PyMethodDef union_methods[] = {
    // union_mro_entries() always raises TypeError ("Cannot subclass ...").
    {"__mro_entries__", union_mro_entries, METH_O},
    // Class-level subscription: builds a checked union from the arguments.
    {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
    {0} // sentinel
};
|
|
|
|
// The typing.Union type object. Instances are GC-tracked (Py_TPFLAGS_HAVE_GC,
// union_traverse), weak-referenceable, hashable only when all members are
// (union_hash), and support `|`, subscription and ==/!=.
PyTypeObject _PyUnion_Type = {
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
    .tp_name = "typing.Union",
    .tp_doc = PyDoc_STR("Represent a union type\n"
              "\n"
              "E.g. for int | str"),
    .tp_basicsize = sizeof(unionobject),
    .tp_dealloc = unionobject_dealloc,
    // NOTE(review): make_union() allocates with PyObject_GC_New rather than
    // tp_alloc; this slot is not used by any code in this file.
    .tp_alloc = PyType_GenericAlloc,
    .tp_free = PyObject_GC_Del,
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
    .tp_traverse = union_traverse,
    .tp_hash = union_hash,
    .tp_getattro = union_getattro,
    .tp_members = union_members,
    .tp_methods = union_methods,
    .tp_richcompare = union_richcompare,
    .tp_as_mapping = &union_as_mapping,
    .tp_as_number = &union_as_number,
    .tp_repr = union_repr,
    .tp_getset = union_properties,
    .tp_weaklistoffset = offsetof(unionobject, weakreflist),
};
|
|
|
|
static PyObject *
|
|
make_union(unionbuilder *ub)
|
|
{
|
|
Py_ssize_t n = PyList_GET_SIZE(ub->args);
|
|
if (n == 0) {
|
|
PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
|
|
unionbuilder_finalize(ub);
|
|
return NULL;
|
|
}
|
|
if (n == 1) {
|
|
PyObject *result = PyList_GET_ITEM(ub->args, 0);
|
|
Py_INCREF(result);
|
|
unionbuilder_finalize(ub);
|
|
return result;
|
|
}
|
|
|
|
PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
|
|
args = PyList_AsTuple(ub->args);
|
|
if (args == NULL) {
|
|
goto error;
|
|
}
|
|
hashable_args = PyFrozenSet_New(ub->hashable_args);
|
|
if (hashable_args == NULL) {
|
|
goto error;
|
|
}
|
|
if (ub->unhashable_args != NULL) {
|
|
unhashable_args = PyList_AsTuple(ub->unhashable_args);
|
|
if (unhashable_args == NULL) {
|
|
goto error;
|
|
}
|
|
}
|
|
|
|
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
|
|
if (result == NULL) {
|
|
goto error;
|
|
}
|
|
unionbuilder_finalize(ub);
|
|
|
|
result->parameters = NULL;
|
|
result->args = args;
|
|
result->hashable_args = hashable_args;
|
|
result->unhashable_args = unhashable_args;
|
|
result->weakreflist = NULL;
|
|
_PyObject_GC_TRACK(result);
|
|
return (PyObject*)result;
|
|
error:
|
|
Py_XDECREF(args);
|
|
Py_XDECREF(hashable_args);
|
|
Py_XDECREF(unhashable_args);
|
|
unionbuilder_finalize(ub);
|
|
return NULL;
|
|
}
|