Unravelling rich comparison operators

For the next part of my blog series on pulling apart Python's syntactic sugar, I'm going to be tackling rich comparison operators: ==, !=, >, <, >=, <=.

For this post I am going to be picking apart the example of a > b.

Looking at the bytecode

Using the dis module, we can look at the bytecode that CPython generates:

>>> def spam(): a > b
...
>>> import dis; dis.dis(spam)
  1           0 LOAD_GLOBAL              0 (a)
              2 LOAD_GLOBAL              1 (b)
              4 COMPARE_OP               4 (>)
              6 POP_TOP
              8 LOAD_CONST               0 (None)
             10 RETURN_VALUE
Bytecode for a > b

That points us at the COMPARE_OP opcode. Its implementation sends us to the cmp_outcome() function, which delegates all the heavy lifting to PyObject_RichCompare().
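If you would rather check for the opcode programmatically than read the disassembly by eye, something like the following sketch works. Offsets, arguments, and the surrounding opcodes differ between CPython versions, so it only looks for the opcode name and makes no claim about the rest of the output.

```python
import dis


def spam():
    a > b  # never executed; we only inspect the compiled bytecode


# Collect the names of the opcodes CPython generated for spam().
opnames = [instruction.opname for instruction in dis.get_instructions(spam)]
has_compare_op = "COMPARE_OP" in opnames
print(has_compare_op)
```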

How rich comparisons work

With PyObject_RichCompare() delegating to do_richcompare(), the code matches up to the explanation in the data model. Each comparison operator has a matching special/magic method:

  • ==: __eq__
  • !=: __ne__
  • <: __lt__
  • >: __gt__
  • <=: __le__
  • >=: __ge__

So for our a > b example we care about __gt__. That leads us to writing the following Python code to implement the equivalent of operator.gt() (debuiltins._mro_getattr() is just a helper that looks up attributes on the type rather than on the instance, as Python always does for special/magic methods; it's a perf thing):

def __gt__(lhs, rhs, /):
    lhs_type = type(lhs)
    try:
        lhs_method = debuiltins._mro_getattr(lhs_type, "__gt__")
    except AttributeError:
        pass
    else:
        result = lhs_method(lhs, rhs)
        if result is not NotImplemented:
            return result
    raise TypeError(
        f"'>' not supported between instances of {lhs_type!r} and {type(rhs)!r}"
    )
Implementation of operator.gt() without worrying about the right-hand object
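To make the "look it up on the type" detail concrete, here is a small sketch. The mro_getattr() function below is a hypothetical stand-in written for this example; it is an assumption about what such a helper might look like, not the actual debuiltins._mro_getattr() code.

```python
class Always:
    def __gt__(self, other):
        return "type-level"


obj = Always()
# An instance attribute with the same name is ignored by the > syntax.
obj.__gt__ = lambda other: "instance-level"
result = obj > 0


def mro_getattr(type_, name):
    """Hypothetical helper: find `name` in the namespaces of the type's MRO."""
    for base in type_.__mro__:
        namespace = vars(base)
        if name in namespace:
            # The raw function is returned, so callers pass the instance explicitly.
            return namespace[name]
    raise AttributeError(f"{type_.__name__!r} has no attribute {name!r}")


method = mro_getattr(Always, "__gt__")
print(result, method(obj, 0))
```

Because the helper returns the plain function stored on the class, it has to be called as method(obj, 0), which mirrors the lhs_method(lhs, rhs) call in the code above.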

Now each comparison has a reflection, so that if the left-hand side of the comparison expression doesn't implement the appropriate special method, you at least have a chance at using the right-hand side to get what you want. The pairings are:

  • __lt__ and __gt__
  • __le__ and __ge__
  • __eq__ and itself
  • __ne__ and itself

What this means (roughly) is that if a > b doesn't work then we can try b < a. Now the data model, much like with binary arithmetic operators, has some fanciness to it when it comes to the right-hand side of the expression. If:

  1. The right-hand side is not the same type as the left-hand side
  2. But the right-hand side's type is a subclass of the left-hand side's type

then we try the right-hand side's way of doing things first (e.g. b < a). The reason for this rule is just like with binary arithmetic operators: if subclasses on the right-hand side want to do something special they get a chance to. For example, if b wanted to make sure to return an instance of its own type, it would only get that chance if a did not go first; otherwise a > b could return an instance of a's type instead of b's.
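A quick way to see the ordering rule in action is to record which method gets called. The Base and Derived classes and the trace() helper here are made up purely for illustration.

```python
calls = []


class Base:
    def __gt__(self, other):
        calls.append("Base.__gt__")
        return True

    def __lt__(self, other):
        calls.append("Base.__lt__")
        return True


class Derived(Base):
    def __lt__(self, other):
        calls.append("Derived.__lt__")
        return True


def trace(lhs, rhs):
    """Return the names of the methods Python called for `lhs > rhs`."""
    calls.clear()
    lhs > rhs
    return list(calls)


same_type = trace(Base(), Base())  # left-hand side's __gt__ goes first
subclass_on_right = trace(Base(), Derived())  # right-hand side's __lt__ goes first
subclass_on_left = trace(Derived(), Base())  # left-hand side's __gt__ goes first
print(same_type, subclass_on_right, subclass_on_left)
```

Only the case where the right-hand side is an instance of a proper subclass flips the order; in the other two cases the left-hand side's __gt__ is tried first.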

Putting this all together gets us:

def __gt__(lhs, rhs, /):
    lhs_type = type(lhs)
    try:
        lhs_method = debuiltins._mro_getattr(lhs_type, "__gt__")
    except AttributeError:
        lhs_method = _MISSING

    rhs_type = type(rhs)
    try:
        rhs_method = debuiltins._mro_getattr(rhs_type, "__lt__")
    except AttributeError:
        rhs_method = _MISSING

    call_lhs = lhs, lhs_method, rhs
    call_rhs = rhs, rhs_method, lhs

    if (
        rhs_method is not _MISSING  # Do we care?
        and rhs_type is not lhs_type  # Could RHS be an actual subclass?
        and issubclass(rhs_type, lhs_type)  # Is RHS a subclass?
    ):
        calls = call_rhs, call_lhs
    else:
        calls = call_lhs, call_rhs

    for first_obj, meth, second_obj in calls:
        if meth is _MISSING:
            continue
        value = meth(first_obj, second_obj)
        if value is not NotImplemented:
            return value
    else:
        raise TypeError(
            f"unsupported operand type(s) for '>': {lhs_type!r} and {rhs_type!r}"
        )
Implementation of operator.gt() where the reflected operation is considered.

If you generalize this out to the other comparisons and their reflection you have the operations work appropriately for either argument!
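As a small illustration of reflection doing its job, the hypothetical Metres class below only defines __lt__, yet > still works because Python falls back on the right-hand side's __lt__ once the inherited object.__gt__ returns NotImplemented.

```python
class Metres:
    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        if not isinstance(other, Metres):
            return NotImplemented
        return self.value < other.value


# No __gt__ is defined on Metres, so these go through the reflected __lt__.
bigger = Metres(5) > Metres(3)
smaller = Metres(1) > Metres(3)

# When neither side can handle the comparison, TypeError is raised.
try:
    Metres(5) > 3
except TypeError as exc:
    error = exc
else:
    error = None

print(bigger, smaller, type(error).__name__)
```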

== and != can never fail

So we have a solution for > which can be generalized, but there's one more thing we need to contend with. In case you weren't aware, neither == nor != will raise TypeError when the special/magic methods are missing or return NotImplemented. Instead, Python falls back on comparing the values of id() for each object: == checks that they are the same object, and != checks that they are not.
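A short sketch of the difference, using a made-up Opaque class that defines no comparison methods of its own:

```python
class Opaque:
    """Defines no comparison methods of its own."""


a, b = Opaque(), Opaque()

# == and != fall back on identity instead of raising.
equal_to_other = a == b
equal_to_self = a == a
not_equal_to_other = a != b

# The ordering comparisons have no such fallback.
try:
    a < b
except TypeError:
    ordering_failed = True
else:
    ordering_failed = False

print(equal_to_other, equal_to_self, not_equal_to_other, ordering_failed)
```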

Back in the Python 2 days, you could compare any objects using any comparison operator and you would get a result. But those semantics led to odd cases where bad data in a list, for instance, would still be sortable. By making only == and != always succeed (unless their special methods raise an exception), you prevent such unexpected interactions between objects and stop errors from passing silently (although some people wish even this special case for == and != didn't exist).

And with that, we get a complete implementation for rich comparisons!

def _create_rich_comparison(
    operator: str, name: str, reflection: str, default: Callable[[str, Any, Any], bool]
) -> Callable[[Any, Any], Any]:
    """Create a rich comparison function.

    The 'operator' parameter is the human-readable symbol of the operation (e.g.
    `>`). The 'name' parameter is the primary function (e.g. __gt__), while
    'reflection' is the reflection of that function (e.g. __lt__). The 'default'
    parameter is a callable to use when both functions don't exist and/or return
    NotImplemented.

    """

    def _rich_comparison(lhs: Any, rhs: Any, /) -> Any:
        lhs_type = type(lhs)
        try:
            lhs_method = debuiltins._mro_getattr(lhs_type, name)
        except AttributeError:
            lhs_method = _MISSING

        rhs_type = type(rhs)
        try:
            rhs_method = debuiltins._mro_getattr(rhs_type, reflection)
        except AttributeError:
            rhs_method = _MISSING

        call_lhs = lhs, lhs_method, rhs
        call_rhs = rhs, rhs_method, lhs

        if _is_proper_subclass(rhs_type, lhs_type):
            calls = call_rhs, call_lhs
        else:
            calls = call_lhs, call_rhs

        for first_obj, meth, second_obj in calls:
            if meth is _MISSING:
                continue
            value = meth(first_obj, second_obj)
            if value is not NotImplemented:
                return value
        else:
            return default(operator, lhs, rhs)

    _rich_comparison.__name__ = _rich_comparison.__qualname__ = name
    _rich_comparison.__doc__ = f"Implement the rich comparison `a {operator} b`."
    return _rich_comparison


def _rich_comparison_unsupported(operator: str, lhs: Any, rhs: Any) -> None:
    """Raise TypeError when a rich comparison has no fallback logic."""
    raise TypeError(
        f"unsupported operand type(s) for {operator!r}: {type(lhs)!r} and {type(rhs)!r}"
    )


gt = __gt__ = _create_rich_comparison(
    ">", "__gt__", "__lt__", _rich_comparison_unsupported
)
# ... other rich comparisons ...
eq = __eq__ = _create_rich_comparison(
    "==", "__eq__", "__eq__", lambda _, a, b: id(a) == id(b)
)
Function to create rich comparison operators

__eq__ and __ne__ on object

If you look at how object implements rich comparison, you will see it implements __eq__ and __ne__ (the other special methods for rich comparison on object are just a side-effect of using a single C function to implement all rich comparison special methods). For __eq__, the code does an id() check much like the default semantics for ==: when the check succeeds it returns True, but if the IDs differ then NotImplemented is returned. The reason for returning NotImplemented instead of False is to allow the other object's __eq__ to participate in the operation; if that doesn't produce a result either, the operation falls through to the default semantics for ==, which eventually return False.

For __ne__, the data model explicitly states that "__ne__() delegates to __eq__() and inverts the result unless it is NotImplemented". That lets you just override __eq__ and get changed semantics for __ne__ automatically. The NotImplemented result has the same effect as it does for ==: it lets the default semantics take over, which for != means checking that the IDs of the objects don't match.

class object:

    def __eq__(self, other, /) -> Union[Literal[True], NotImplemented]:
        """Implement equality via identity.

        If the objects are not equal then return NotImplemented to give the
        other object's __eq__ implementation a chance to participate in the
        comparison.

        """
        # https://github.com/python/cpython/blob/v3.8.3/Objects/typeobject.c#L3834-L3880
        return (self is other) or NotImplemented

    def __ne__(self, other, /) -> Union[bool, NotImplemented]:
        """Implement inequality by delegating to __eq__."""
        # https://github.com/python/cpython/blob/v3.8.3/Objects/typeobject.c#L3834-L3880
        result = self.__eq__(other)
        if result is not NotImplemented:
            return not result
        else:
            return NotImplemented
Implementations of __eq__ and __ne__ for object
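You can poke at this behaviour directly by calling the methods on object itself. The Point class is a made-up example showing that overriding only __eq__ is enough to change what != does.

```python
a, b = object(), object()

same_eq = object.__eq__(a, a)  # identity check succeeds
different_eq = object.__eq__(a, b)  # lets the other object have a say
same_ne = object.__ne__(a, a)  # inverts the __eq__ result
different_ne = object.__ne__(a, b)  # passes NotImplemented through


class Point:
    def __init__(self, x):
        self.x = x

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x


# Point never defines __ne__; the inherited one delegates to Point.__eq__.
equal_points_differ = Point(1) != Point(1)
unequal_points_differ = Point(1) != Point(2)
print(same_eq, different_eq, same_ne, different_ne)
print(equal_points_differ, unequal_points_differ)
```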

As for why bother having these methods predefined on object when the default semantics for == and != do the same thing, it still lets you call the methods directly. Something I think a lot of people don't think about is the fact that you can not only call these methods directly to skip over the syntax, but pass them around like any other object. That's handy if you're passing methods around as callbacks or something.
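As a rough sketch of what that looks like, a bound __eq__ can be handed to map() like any other callable. Keep in mind that calling the method directly skips the reflection and fallback logic, so you can get NotImplemented back where the == syntax would have found an answer.

```python
target = 3

# A bound special method is an ordinary callable.
matches = list(map(target.__eq__, [1, 2, 3]))

# Calling the method directly bypasses the reflected lookup that == performs.
direct = target.__eq__(3.0)  # int.__eq__ doesn't know about floats
via_syntax = target == 3.0  # float's __eq__ gets a chance via reflection

print(matches, direct, via_syntax)
```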

Conclusion

And that's it! As with the other posts in the series, you can find the source code in my desugar project.