(root)/
Python-3.12.0/
Objects/
unionobject.c
       1  // types.UnionType -- used to represent e.g. Union[int, str], int | str
       2  #include "Python.h"
       3  #include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
       4  #include "pycore_typevarobject.h"  // _PyTypeAlias_Type
       5  #include "pycore_unionobject.h"
       6  #include "structmember.h"
       7  
       8  
       9  static PyObject *make_union(PyObject *);
      10  
      11  
      12  typedef struct {
      13      PyObject_HEAD
      14      PyObject *args;
      15      PyObject *parameters;
      16  } unionobject;
      17  
      18  static void
      19  unionobject_dealloc(PyObject *self)
      20  {
      21      unionobject *alias = (unionobject *)self;
      22  
      23      _PyObject_GC_UNTRACK(self);
      24  
      25      Py_XDECREF(alias->args);
      26      Py_XDECREF(alias->parameters);
      27      Py_TYPE(self)->tp_free(self);
      28  }
      29  
      30  static int
      31  union_traverse(PyObject *self, visitproc visit, void *arg)
      32  {
      33      unionobject *alias = (unionobject *)self;
      34      Py_VISIT(alias->args);
      35      Py_VISIT(alias->parameters);
      36      return 0;
      37  }
      38  
      39  static Py_hash_t
      40  union_hash(PyObject *self)
      41  {
      42      unionobject *alias = (unionobject *)self;
      43      PyObject *args = PyFrozenSet_New(alias->args);
      44      if (args == NULL) {
      45          return (Py_hash_t)-1;
      46      }
      47      Py_hash_t hash = PyObject_Hash(args);
      48      Py_DECREF(args);
      49      return hash;
      50  }
      51  
      52  static PyObject *
      53  union_richcompare(PyObject *a, PyObject *b, int op)
      54  {
      55      if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
      56          Py_RETURN_NOTIMPLEMENTED;
      57      }
      58  
      59      PyObject *a_set = PySet_New(((unionobject*)a)->args);
      60      if (a_set == NULL) {
      61          return NULL;
      62      }
      63      PyObject *b_set = PySet_New(((unionobject*)b)->args);
      64      if (b_set == NULL) {
      65          Py_DECREF(a_set);
      66          return NULL;
      67      }
      68      PyObject *result = PyObject_RichCompare(a_set, b_set, op);
      69      Py_DECREF(b_set);
      70      Py_DECREF(a_set);
      71      return result;
      72  }
      73  
      74  static int
      75  is_same(PyObject *left, PyObject *right)
      76  {
      77      int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
      78      return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
      79  }
      80  
      81  static int
      82  contains(PyObject **items, Py_ssize_t size, PyObject *obj)
      83  {
      84      for (int i = 0; i < size; i++) {
      85          int is_duplicate = is_same(items[i], obj);
      86          if (is_duplicate) {  // -1 or 1
      87              return is_duplicate;
      88          }
      89      }
      90      return 0;
      91  }
      92  
      93  static PyObject *
      94  merge(PyObject **items1, Py_ssize_t size1,
      95        PyObject **items2, Py_ssize_t size2)
      96  {
      97      PyObject *tuple = NULL;
      98      Py_ssize_t pos = 0;
      99  
     100      for (int i = 0; i < size2; i++) {
     101          PyObject *arg = items2[i];
     102          int is_duplicate = contains(items1, size1, arg);
     103          if (is_duplicate < 0) {
     104              Py_XDECREF(tuple);
     105              return NULL;
     106          }
     107          if (is_duplicate) {
     108              continue;
     109          }
     110  
     111          if (tuple == NULL) {
     112              tuple = PyTuple_New(size1 + size2 - i);
     113              if (tuple == NULL) {
     114                  return NULL;
     115              }
     116              for (; pos < size1; pos++) {
     117                  PyObject *a = items1[pos];
     118                  PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
     119              }
     120          }
     121          PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
     122          pos++;
     123      }
     124  
     125      if (tuple) {
     126          (void) _PyTuple_Resize(&tuple, pos);
     127      }
     128      return tuple;
     129  }
     130  
     131  static PyObject **
     132  get_types(PyObject **obj, Py_ssize_t *size)
     133  {
     134      if (*obj == Py_None) {
     135          *obj = (PyObject *)&_PyNone_Type;
     136      }
     137      if (_PyUnion_Check(*obj)) {
     138          PyObject *args = ((unionobject *) *obj)->args;
     139          *size = PyTuple_GET_SIZE(args);
     140          return &PyTuple_GET_ITEM(args, 0);
     141      }
     142      else {
     143          *size = 1;
     144          return obj;
     145      }
     146  }
     147  
     148  static int
     149  is_unionable(PyObject *obj)
     150  {
     151      if (obj == Py_None ||
     152          PyType_Check(obj) ||
     153          _PyGenericAlias_Check(obj) ||
     154          _PyUnion_Check(obj) ||
     155          Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
     156          return 1;
     157      }
     158      return 0;
     159  }
     160  
     161  PyObject *
     162  _Py_union_type_or(PyObject* self, PyObject* other)
     163  {
     164      if (!is_unionable(self) || !is_unionable(other)) {
     165          Py_RETURN_NOTIMPLEMENTED;
     166      }
     167  
     168      Py_ssize_t size1, size2;
     169      PyObject **items1 = get_types(&self, &size1);
     170      PyObject **items2 = get_types(&other, &size2);
     171      PyObject *tuple = merge(items1, size1, items2, size2);
     172      if (tuple == NULL) {
     173          if (PyErr_Occurred()) {
     174              return NULL;
     175          }
     176          return Py_NewRef(self);
     177      }
     178  
     179      PyObject *new_union = make_union(tuple);
     180      Py_DECREF(tuple);
     181      return new_union;
     182  }
     183  
     184  static int
     185  union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
     186  {
     187      PyObject *qualname = NULL;
     188      PyObject *module = NULL;
     189      PyObject *tmp;
     190      PyObject *r = NULL;
     191      int err;
     192  
     193      if (p == (PyObject *)&_PyNone_Type) {
     194          return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
     195      }
     196  
     197      if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
     198          goto exit;
     199      }
     200  
     201      if (tmp) {
     202          Py_DECREF(tmp);
     203          if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
     204              goto exit;
     205          }
     206          if (tmp) {
     207              // It looks like a GenericAlias
     208              Py_DECREF(tmp);
     209              goto use_repr;
     210          }
     211      }
     212  
     213      if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
     214          goto exit;
     215      }
     216      if (qualname == NULL) {
     217          goto use_repr;
     218      }
     219      if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
     220          goto exit;
     221      }
     222      if (module == NULL || module == Py_None) {
     223          goto use_repr;
     224      }
     225  
     226      // Looks like a class
     227      if (PyUnicode_Check(module) &&
     228          _PyUnicode_EqualToASCIIString(module, "builtins"))
     229      {
     230          // builtins don't need a module name
     231          r = PyObject_Str(qualname);
     232          goto exit;
     233      }
     234      else {
     235          r = PyUnicode_FromFormat("%S.%S", module, qualname);
     236          goto exit;
     237      }
     238  
     239  use_repr:
     240      r = PyObject_Repr(p);
     241  exit:
     242      Py_XDECREF(qualname);
     243      Py_XDECREF(module);
     244      if (r == NULL) {
     245          return -1;
     246      }
     247      err = _PyUnicodeWriter_WriteStr(writer, r);
     248      Py_DECREF(r);
     249      return err;
     250  }
     251  
     252  static PyObject *
     253  union_repr(PyObject *self)
     254  {
     255      unionobject *alias = (unionobject *)self;
     256      Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
     257  
     258      _PyUnicodeWriter writer;
     259      _PyUnicodeWriter_Init(&writer);
     260       for (Py_ssize_t i = 0; i < len; i++) {
     261          if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
     262              goto error;
     263          }
     264          PyObject *p = PyTuple_GET_ITEM(alias->args, i);
     265          if (union_repr_item(&writer, p) < 0) {
     266              goto error;
     267          }
     268      }
     269      return _PyUnicodeWriter_Finish(&writer);
     270  error:
     271      _PyUnicodeWriter_Dealloc(&writer);
     272      return NULL;
     273  }
     274  
     275  static PyMemberDef union_members[] = {
     276          {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
     277          {0}
     278  };
     279  
     280  static PyObject *
     281  union_getitem(PyObject *self, PyObject *item)
     282  {
     283      unionobject *alias = (unionobject *)self;
     284      // Populate __parameters__ if needed.
     285      if (alias->parameters == NULL) {
     286          alias->parameters = _Py_make_parameters(alias->args);
     287          if (alias->parameters == NULL) {
     288              return NULL;
     289          }
     290      }
     291  
     292      PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
     293      if (newargs == NULL) {
     294          return NULL;
     295      }
     296  
     297      PyObject *res;
     298      Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
     299      if (nargs == 0) {
     300          res = make_union(newargs);
     301      }
     302      else {
     303          res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
     304          for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
     305              PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
     306              Py_SETREF(res, PyNumber_Or(res, arg));
     307              if (res == NULL) {
     308                  break;
     309              }
     310          }
     311      }
     312      Py_DECREF(newargs);
     313      return res;
     314  }
     315  
     316  static PyMappingMethods union_as_mapping = {
     317      .mp_subscript = union_getitem,
     318  };
     319  
     320  static PyObject *
     321  union_parameters(PyObject *self, void *Py_UNUSED(unused))
     322  {
     323      unionobject *alias = (unionobject *)self;
     324      if (alias->parameters == NULL) {
     325          alias->parameters = _Py_make_parameters(alias->args);
     326          if (alias->parameters == NULL) {
     327              return NULL;
     328          }
     329      }
     330      return Py_NewRef(alias->parameters);
     331  }
     332  
     333  static PyGetSetDef union_properties[] = {
     334      {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
     335      {0}
     336  };
     337  
     338  static PyNumberMethods union_as_number = {
     339          .nb_or = _Py_union_type_or, // Add __or__ function
     340  };
     341  
     342  static const char* const cls_attrs[] = {
     343          "__module__",  // Required for compatibility with typing module
     344          NULL,
     345  };
     346  
     347  static PyObject *
     348  union_getattro(PyObject *self, PyObject *name)
     349  {
     350      unionobject *alias = (unionobject *)self;
     351      if (PyUnicode_Check(name)) {
     352          for (const char * const *p = cls_attrs; ; p++) {
     353              if (*p == NULL) {
     354                  break;
     355              }
     356              if (_PyUnicode_EqualToASCIIString(name, *p)) {
     357                  return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
     358              }
     359          }
     360      }
     361      return PyObject_GenericGetAttr(self, name);
     362  }
     363  
     364  PyObject *
     365  _Py_union_args(PyObject *self)
     366  {
     367      assert(_PyUnion_Check(self));
     368      return ((unionobject *) self)->args;
     369  }
     370  
     371  PyTypeObject _PyUnion_Type = {
     372      PyVarObject_HEAD_INIT(&PyType_Type, 0)
     373      .tp_name = "types.UnionType",
     374      .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
     375                "\n"
     376                "E.g. for int | str"),
     377      .tp_basicsize = sizeof(unionobject),
     378      .tp_dealloc = unionobject_dealloc,
     379      .tp_alloc = PyType_GenericAlloc,
     380      .tp_free = PyObject_GC_Del,
     381      .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
     382      .tp_traverse = union_traverse,
     383      .tp_hash = union_hash,
     384      .tp_getattro = union_getattro,
     385      .tp_members = union_members,
     386      .tp_richcompare = union_richcompare,
     387      .tp_as_mapping = &union_as_mapping,
     388      .tp_as_number = &union_as_number,
     389      .tp_repr = union_repr,
     390      .tp_getset = union_properties,
     391  };
     392  
     393  static PyObject *
     394  make_union(PyObject *args)
     395  {
     396      assert(PyTuple_CheckExact(args));
     397  
     398      unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
     399      if (result == NULL) {
     400          return NULL;
     401      }
     402  
     403      result->parameters = NULL;
     404      result->args = Py_NewRef(args);
     405      _PyObject_GC_TRACK(result);
     406      return (PyObject*)result;
     407  }