(root)/
Python-3.11.7/
Objects/
unionobject.c
       1  // types.UnionType -- used to represent e.g. Union[int, str], int | str
       2  #include "Python.h"
       3  #include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
       4  #include "pycore_unionobject.h"
       5  #include "structmember.h"
       6  
       7  
       8  static PyObject *make_union(PyObject *);
       9  
      10  
/* Instance layout of types.UnionType objects. */
typedef struct {
    PyObject_HEAD
    PyObject *args;        // member types; an exact tuple (see make_union), exposed as __args__
    PyObject *parameters;  // cached __parameters__; NULL until first computed on demand
} unionobject;
      16  
      17  static void
      18  unionobject_dealloc(PyObject *self)
      19  {
      20      unionobject *alias = (unionobject *)self;
      21  
      22      _PyObject_GC_UNTRACK(self);
      23  
      24      Py_XDECREF(alias->args);
      25      Py_XDECREF(alias->parameters);
      26      Py_TYPE(self)->tp_free(self);
      27  }
      28  
      29  static int
      30  union_traverse(PyObject *self, visitproc visit, void *arg)
      31  {
      32      unionobject *alias = (unionobject *)self;
      33      Py_VISIT(alias->args);
      34      Py_VISIT(alias->parameters);
      35      return 0;
      36  }
      37  
      38  static Py_hash_t
      39  union_hash(PyObject *self)
      40  {
      41      unionobject *alias = (unionobject *)self;
      42      PyObject *args = PyFrozenSet_New(alias->args);
      43      if (args == NULL) {
      44          return (Py_hash_t)-1;
      45      }
      46      Py_hash_t hash = PyObject_Hash(args);
      47      Py_DECREF(args);
      48      return hash;
      49  }
      50  
      51  static PyObject *
      52  union_richcompare(PyObject *a, PyObject *b, int op)
      53  {
      54      if (!_PyUnion_Check(b) || (op != Py_EQ && op != Py_NE)) {
      55          Py_RETURN_NOTIMPLEMENTED;
      56      }
      57  
      58      PyObject *a_set = PySet_New(((unionobject*)a)->args);
      59      if (a_set == NULL) {
      60          return NULL;
      61      }
      62      PyObject *b_set = PySet_New(((unionobject*)b)->args);
      63      if (b_set == NULL) {
      64          Py_DECREF(a_set);
      65          return NULL;
      66      }
      67      PyObject *result = PyObject_RichCompare(a_set, b_set, op);
      68      Py_DECREF(b_set);
      69      Py_DECREF(a_set);
      70      return result;
      71  }
      72  
      73  static int
      74  is_same(PyObject *left, PyObject *right)
      75  {
      76      int is_ga = _PyGenericAlias_Check(left) && _PyGenericAlias_Check(right);
      77      return is_ga ? PyObject_RichCompareBool(left, right, Py_EQ) : left == right;
      78  }
      79  
      80  static int
      81  contains(PyObject **items, Py_ssize_t size, PyObject *obj)
      82  {
      83      for (int i = 0; i < size; i++) {
      84          int is_duplicate = is_same(items[i], obj);
      85          if (is_duplicate) {  // -1 or 1
      86              return is_duplicate;
      87          }
      88      }
      89      return 0;
      90  }
      91  
      92  static PyObject *
      93  merge(PyObject **items1, Py_ssize_t size1,
      94        PyObject **items2, Py_ssize_t size2)
      95  {
      96      PyObject *tuple = NULL;
      97      Py_ssize_t pos = 0;
      98  
      99      for (int i = 0; i < size2; i++) {
     100          PyObject *arg = items2[i];
     101          int is_duplicate = contains(items1, size1, arg);
     102          if (is_duplicate < 0) {
     103              Py_XDECREF(tuple);
     104              return NULL;
     105          }
     106          if (is_duplicate) {
     107              continue;
     108          }
     109  
     110          if (tuple == NULL) {
     111              tuple = PyTuple_New(size1 + size2 - i);
     112              if (tuple == NULL) {
     113                  return NULL;
     114              }
     115              for (; pos < size1; pos++) {
     116                  PyObject *a = items1[pos];
     117                  Py_INCREF(a);
     118                  PyTuple_SET_ITEM(tuple, pos, a);
     119              }
     120          }
     121          Py_INCREF(arg);
     122          PyTuple_SET_ITEM(tuple, pos, arg);
     123          pos++;
     124      }
     125  
     126      if (tuple) {
     127          (void) _PyTuple_Resize(&tuple, pos);
     128      }
     129      return tuple;
     130  }
     131  
     132  static PyObject **
     133  get_types(PyObject **obj, Py_ssize_t *size)
     134  {
     135      if (*obj == Py_None) {
     136          *obj = (PyObject *)&_PyNone_Type;
     137      }
     138      if (_PyUnion_Check(*obj)) {
     139          PyObject *args = ((unionobject *) *obj)->args;
     140          *size = PyTuple_GET_SIZE(args);
     141          return &PyTuple_GET_ITEM(args, 0);
     142      }
     143      else {
     144          *size = 1;
     145          return obj;
     146      }
     147  }
     148  
     149  static int
     150  is_unionable(PyObject *obj)
     151  {
     152      return (obj == Py_None ||
     153          PyType_Check(obj) ||
     154          _PyGenericAlias_Check(obj) ||
     155          _PyUnion_Check(obj));
     156  }
     157  
     158  PyObject *
     159  _Py_union_type_or(PyObject* self, PyObject* other)
     160  {
     161      if (!is_unionable(self) || !is_unionable(other)) {
     162          Py_RETURN_NOTIMPLEMENTED;
     163      }
     164  
     165      Py_ssize_t size1, size2;
     166      PyObject **items1 = get_types(&self, &size1);
     167      PyObject **items2 = get_types(&other, &size2);
     168      PyObject *tuple = merge(items1, size1, items2, size2);
     169      if (tuple == NULL) {
     170          if (PyErr_Occurred()) {
     171              return NULL;
     172          }
     173          Py_INCREF(self);
     174          return self;
     175      }
     176  
     177      PyObject *new_union = make_union(tuple);
     178      Py_DECREF(tuple);
     179      return new_union;
     180  }
     181  
/* Append the display form of one union member `p` to `writer`.
 *
 * Rules, checked in order:
 *   - NoneType is written as "None".
 *   - An object with both __origin__ and __args__ (GenericAlias-like) is
 *     written with repr().
 *   - An object with __qualname__ and a __module__ that is not None is
 *     written as "<module>.<qualname>", or as str(qualname) alone when
 *     __module__ is the string "builtins".
 *   - Anything else is written with repr().
 *
 * Returns 0 on success and a negative value on failure (a failed attribute
 * lookup, str/repr/format call, or write). */
static int
union_repr_item(_PyUnicodeWriter *writer, PyObject *p)
{
    PyObject *qualname = NULL;
    PyObject *module = NULL;
    PyObject *tmp;
    PyObject *r = NULL;  // text to write; still NULL at `exit` means failure
    int err;

    if (p == (PyObject *)&_PyNone_Type) {
        return _PyUnicodeWriter_WriteASCIIString(writer, "None", 4);
    }

    if (_PyObject_LookupAttr(p, &_Py_ID(__origin__), &tmp) < 0) {
        goto exit;
    }

    if (tmp) {
        Py_DECREF(tmp);
        if (_PyObject_LookupAttr(p, &_Py_ID(__args__), &tmp) < 0) {
            goto exit;
        }
        if (tmp) {
            // It looks like a GenericAlias
            Py_DECREF(tmp);
            goto use_repr;
        }
    }

    if (_PyObject_LookupAttr(p, &_Py_ID(__qualname__), &qualname) < 0) {
        goto exit;
    }
    if (qualname == NULL) {
        goto use_repr;
    }
    if (_PyObject_LookupAttr(p, &_Py_ID(__module__), &module) < 0) {
        goto exit;
    }
    if (module == NULL || module == Py_None) {
        goto use_repr;
    }

    // Looks like a class
    if (PyUnicode_Check(module) &&
        _PyUnicode_EqualToASCIIString(module, "builtins"))
    {
        // builtins don't need a module name
        r = PyObject_Str(qualname);
        goto exit;
    }
    else {
        r = PyUnicode_FromFormat("%S.%S", module, qualname);
        goto exit;
    }

// Reached only through the gotos above: fall back to repr(p).
use_repr:
    r = PyObject_Repr(p);
exit:
    // Shared cleanup; qualname/module remain NULL if never looked up.
    Py_XDECREF(qualname);
    Py_XDECREF(module);
    if (r == NULL) {
        return -1;
    }
    err = _PyUnicodeWriter_WriteStr(writer, r);
    Py_DECREF(r);
    return err;
}
     249  
     250  static PyObject *
     251  union_repr(PyObject *self)
     252  {
     253      unionobject *alias = (unionobject *)self;
     254      Py_ssize_t len = PyTuple_GET_SIZE(alias->args);
     255  
     256      _PyUnicodeWriter writer;
     257      _PyUnicodeWriter_Init(&writer);
     258       for (Py_ssize_t i = 0; i < len; i++) {
     259          if (i > 0 && _PyUnicodeWriter_WriteASCIIString(&writer, " | ", 3) < 0) {
     260              goto error;
     261          }
     262          PyObject *p = PyTuple_GET_ITEM(alias->args, i);
     263          if (union_repr_item(&writer, p) < 0) {
     264              goto error;
     265          }
     266      }
     267      return _PyUnicodeWriter_Finish(&writer);
     268  error:
     269      _PyUnicodeWriter_Dealloc(&writer);
     270      return NULL;
     271  }
     272  
     273  static PyMemberDef union_members[] = {
     274          {"__args__", T_OBJECT, offsetof(unionobject, args), READONLY},
     275          {0}
     276  };
     277  
     278  static PyObject *
     279  union_getitem(PyObject *self, PyObject *item)
     280  {
     281      unionobject *alias = (unionobject *)self;
     282      // Populate __parameters__ if needed.
     283      if (alias->parameters == NULL) {
     284          alias->parameters = _Py_make_parameters(alias->args);
     285          if (alias->parameters == NULL) {
     286              return NULL;
     287          }
     288      }
     289  
     290      PyObject *newargs = _Py_subs_parameters(self, alias->args, alias->parameters, item);
     291      if (newargs == NULL) {
     292          return NULL;
     293      }
     294  
     295      PyObject *res;
     296      Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
     297      if (nargs == 0) {
     298          res = make_union(newargs);
     299      }
     300      else {
     301          res = PyTuple_GET_ITEM(newargs, 0);
     302          Py_INCREF(res);
     303          for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
     304              PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
     305              Py_SETREF(res, PyNumber_Or(res, arg));
     306              if (res == NULL) {
     307                  break;
     308              }
     309          }
     310      }
     311      Py_DECREF(newargs);
     312      return res;
     313  }
     314  
/* Mapping protocol: only subscription (u[item]) is provided, by
   union_getitem(). */
static PyMappingMethods union_as_mapping = {
    .mp_subscript = union_getitem,
};
     318  
     319  static PyObject *
     320  union_parameters(PyObject *self, void *Py_UNUSED(unused))
     321  {
     322      unionobject *alias = (unionobject *)self;
     323      if (alias->parameters == NULL) {
     324          alias->parameters = _Py_make_parameters(alias->args);
     325          if (alias->parameters == NULL) {
     326              return NULL;
     327          }
     328      }
     329      Py_INCREF(alias->parameters);
     330      return alias->parameters;
     331  }
     332  
     333  static PyGetSetDef union_properties[] = {
     334      {"__parameters__", union_parameters, (setter)NULL, "Type variables in the types.UnionType.", NULL},
     335      {0}
     336  };
     337  
/* Number protocol: only the | operator, so that an existing union can be
   extended further (e.g. (int | str) | bytes). */
static PyNumberMethods union_as_number = {
    .nb_or = _Py_union_type_or, // Add __or__ function
};
     341  
     342  static const char* const cls_attrs[] = {
     343          "__module__",  // Required for compatibility with typing module
     344          NULL,
     345  };
     346  
     347  static PyObject *
     348  union_getattro(PyObject *self, PyObject *name)
     349  {
     350      unionobject *alias = (unionobject *)self;
     351      if (PyUnicode_Check(name)) {
     352          for (const char * const *p = cls_attrs; ; p++) {
     353              if (*p == NULL) {
     354                  break;
     355              }
     356              if (_PyUnicode_EqualToASCIIString(name, *p)) {
     357                  return PyObject_GetAttr((PyObject *) Py_TYPE(alias), name);
     358              }
     359          }
     360      }
     361      return PyObject_GenericGetAttr(self, name);
     362  }
     363  
     364  PyObject *
     365  _Py_union_args(PyObject *self)
     366  {
     367      assert(_PyUnion_Check(self));
     368      return ((unionobject *) self)->args;
     369  }
     370  
     371  PyTypeObject _PyUnion_Type = {
     372      PyVarObject_HEAD_INIT(&PyType_Type, 0)
     373      .tp_name = "types.UnionType",
     374      .tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
     375                "\n"
     376                "E.g. for int | str"),
     377      .tp_basicsize = sizeof(unionobject),
     378      .tp_dealloc = unionobject_dealloc,
     379      .tp_alloc = PyType_GenericAlloc,
     380      .tp_free = PyObject_GC_Del,
     381      .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
     382      .tp_traverse = union_traverse,
     383      .tp_hash = union_hash,
     384      .tp_getattro = union_getattro,
     385      .tp_members = union_members,
     386      .tp_richcompare = union_richcompare,
     387      .tp_as_mapping = &union_as_mapping,
     388      .tp_as_number = &union_as_number,
     389      .tp_repr = union_repr,
     390      .tp_getset = union_properties,
     391  };
     392  
     393  static PyObject *
     394  make_union(PyObject *args)
     395  {
     396      assert(PyTuple_CheckExact(args));
     397  
     398      unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
     399      if (result == NULL) {
     400          return NULL;
     401      }
     402  
     403      Py_INCREF(args);
     404      result->parameters = NULL;
     405      result->args = args;
     406      _PyObject_GC_TRACK(result);
     407      return (PyObject*)result;
     408  }