有人说 Python 不支持函数重载?

共 9970字,需浏览 20分钟

 ·

2021-10-26 21:33

众所周知,Python 是动态语言,所谓动态语言,就是变量的类型是动态的,程序运行期间变量类型可以随意变化,由于 Python 的变量是没有固定类型的,而函数重载却依赖变量类型,重载就是定义多个同名函数,但这些同名函数的参数类型不同,传入不同类型的参数时执行与之对应的函数。

Python 的变量没有类型,因此 Python 语法本身不支持函数重载,因此有人说 Python 不支持函数重载这话本身是正确的,不过本文想说的是,Python 动态语言的灵活性根本不需要通过函数重载就可以实现一个函数多个功能。不过要让 Python 真正支持函数重载,也就可以的实现的具体来说有两种方案。

方案一、伪重载

Java 那种重载的好处是从函数的形式上可以看出函数支持哪些变量类型,而 Python 由于变量没有固定的类型,这一点可读性就不太好,比如说下面的函数 fun,其实是支持两种参数,一种是全部是字符串,一种是全部都是整数,但是不看代码的话,其实是看不出来的:

def fun(x, y):
    if isinstance(x, str) and isinstance(y, str):
        print(f"str {x =}{y = } ")
    elif isinstance(x, int) and isinstance(y, int):
        print(f"int {x = }{y = }")
fun("hello""world")
fun(1,2)

运行结果

str x ='hello', y = 'world' 
int x = 1, y = 2

不过好在 Python 有类型提示,借助于 Python 的标准库 typing,我们也可以写出重载形式的代码:

import typing


class A:
    @typing.overload
    def fun(self, x: str, y: str) -> None:
        pass

    @typing.overload
    def fun(self, x: int, y: int) -> None:
        pass

    def fun(self, x, y) -> None:
        if isinstance(x, str) and isinstance(y, str):
            print(f"str {x =}{y = } ")
        elif isinstance(x, int) and isinstance(y, int):
            print(f"int {x = }{y = }")


if __name__ == "__main__":
    a = A()
    a.fun("hello""world")
    a.fun(12)

运行结果:

str x ='hello', y = 'world' 
int x = 1, y = 2

这样的话,可读性就提高了,不过这是一种形式上的重载,真正发挥作用的是最后那个没有装饰器的函数,前面两个带装饰器的函数只是为了更好的可读性而存在,没有实际的作用,可以删除,不影响程序运行。

要想实现 Java 那样真正的函数重载,请看方案二。

方案二,借助元类,实现真正的重载

元类是 Python 比较高级的特性,如果一开始就给完整的代码,你可能看不懂,这里循序渐近的展示实现过程。

Python 中一切皆对象,比如说 1 是 int 的实例,int 是 type 实例:

In [7]: a = 5

In [8]: type(a)
Out[8]: int

In [9]: type(int)
Out[9]: type

In [10]:
In [11]: type??
Init signature: type(self, /, *args, **kwargs)
Docstring:
type(object_or_name, bases, dict)
type(object) -> the object's type
type(name, bases, dict) -> a new type
Type:           type
Subclasses:     ABCMeta, EnumMeta, _TemplateMetaclass, _ABC, MetaHasDescriptors, NamedTupleMeta, _TypedDictMeta, LexerMeta, StyleMeta, _NormalizerMeta, ...

从上述可以看出,type(object) 返回 object 的类型,而  type(name, bases, dict) 会产生一个新的类型,也就是说 type(name, bases, dict) 会产生一个 class:

In [17]: A = type('A',(),{})

In [18]: a = A()

In [19]: type(a)
Out[19]: __main__.A
In [20]: type(A)
Out[20]: type

上面的代码,相当于 :

In [21]: class A:
    ...:     pass
    ...:

In [22]: a = A()

In [23]: type(a)
Out[23]: __main__.A

In [24]: type(A)
Out[24]: type

明白了这一点,即使不使用 class 关键字,我们也可以创建出一个类来,比如说下面的 make_A() 和 A() 的作用是一样的:

class A:
    a = 1
    b = "hello"
    def fun(self):
        return "Class A"

def make_A():
    name = 'A'
    bases = ()
    a = 1
    b = "hello"

    def fun():
        return "Class A"

    namespace = {'a':a,'b':b,'fun': fun}

    return type(name,bases,namespace)

if __name__ == '__main__':
    a = A()
    print(a.b)
    print(a.fun())
    print("==="*5)
    b = make_A()
    print(b.b)
    print(b.fun())

执行结果:

hello
Class A
===============
hello
Class A

请注意上述的 make_A 函数里面有一个 namespace,它是一个字典,存储了类的成员变量和成员函数,当我们在一个类中定义多个同名函数时,最后一个会把前面的全部覆盖掉,这是字典的特性,同一个键多次赋值,只会保留最后一个,因此 Python 类不支持函数重载。

现在我们需要保留多个同名函数,那就要改写这个字典,当出现同一个键多次赋值时,将这些值(函数)保留在一个列表中,具体方法编写一个类,继承 dict,然后编写代码如下:

class OverloadDict(dict):
    def __setitem__(self, key, value):
        assert isinstance(key, str), "keys must be str"

        prior_val = self.get(key, _MISSING)
        overloaded = getattr(value, "__overload__"False)

        if prior_val is _MISSING:
            insert_val = OverloadList([value]) if overloaded else value
            super().__setitem__(key, insert_val)
        elif isinstance(prior_val, OverloadList):
            if not overloaded:
                raise ValueError(self._errmsg(key))
            prior_val.append(value)
        else:
            if overloaded:
                raise ValueError(self._errmsg(key))
            super().__setitem__(key, value)

    @staticmethod
    def _errmsg(key):
        return f"must mark all overloads with @overload: {key}"

上述代码有一个关键的地方,那就是如果有 overload 标识,那么就放在列表 prior_val 中:

elif isinstance(prior_val, OverloadList):
    if not overloaded:
        raise ValueError(self._errmsg(key))
    prior_val.append(value)

其中 OverloadList 就是一个列表,其定义如下:

class OverloadList(list):
    pass

再写个装饰器,标识一个函数是否要重载:

def overload(f):
    f.__overload__ = True
    return 

然后我们来测试下这个 OverloadDict,看看它产生的效果:

print("OVERLOAD DICT USAGE")
d = OverloadDict()

@overload
def f(self):
    pass

d["a"] = 1
d["a"] = 2
d["b"] = 3
d["f"] = f
d["f"] = f
print(d)

运行结果:

OVERLOAD DICT USAGE
{'a': 2, 'b': 3, 'f': [<function overload_dict_usage..f at 0x7fdec70090d0>, <function overload_dict_usage..f at 0x7fdec70090d0>]}

OverloadDict 解决了重名函数的如何保存问题,就是把它们放在一个列表中,还有一个问题没有解决,那就是调用的时候如何从列表中取出正确的那个函数来执行?

肯定是根据函数传入的参数类型作为判断依据,那如何实现呢?借助于 Python 的类型提示及自省模块 inspect,当然了,还要借助 Python 的元类:

class OverloadMeta(type):
    @classmethod
    def __prepare__(mcs, name, bases):
        return OverloadDict()

    def __new__(mcs, name, bases, namespace, **kwargs):
        overload_namespace = {
            key: Overload(val) if isinstance(val, OverloadList) else val
            for key, val in namespace.items()
        }
        return super().__new__(mcs, name, bases, overload_namespace, **kwargs)

这里面有个 Overload 类,它的作用就是将函数的签名和定义作一个映射,当我们使用 a.f 时就会调用 __get__ 方法获取对应的函数。其定义如下:


class Overload:
    def __set_name__(self, owner, name):
        self.owner = owner
        self.name = name

    def __init__(self, overload_list):
        if not isinstance(overload_list, OverloadList):
            raise TypeError("must use OverloadList")
        if not overload_list:
            raise ValueError("empty overload list")
        self.overload_list = overload_list
        self.signatures = [inspect.signature(f) for f in overload_list]

    def __repr__(self):
        return f"{self.__class__.__qualname__}({self.overload_list!r})"

    def __get__(self, instance, _owner=None):
        if instance is None:
            return self
        # don't use owner == type(instance)
        # we want self.owner, which is the class from which get is being called
        return BoundOverloadDispatcher(
            instance, self.owner, self.name, self.overload_list, self.signatures
        )

    def extend(self, other):
        if not isinstance(other, Overload):
            raise TypeError
        self.overload_list.extend(other.overload_list)
        self.signatures.extend(other.signatures)

__get__ 返回的是一个 BoundOverloadDispatcher 类,它把参数类型和对应的函数进行了绑定,只要函数被调用时才会调用 __call__ 返回最匹配的函数进行调用:


class BoundOverloadDispatcher:
    def __init__(self, instance, owner_cls, name, overload_list, signatures):
        self.instance = instance
        self.owner_cls = owner_cls
        self.name = name
        self.overload_list = overload_list
        self.signatures = signatures

    def best_match(self, *args, **kwargs):
        for f, sig in zip(self.overload_list, self.signatures):
            try:
                bound_args = sig.bind(self.instance, *args, **kwargs)
            except TypeError:
                pass  # missing/extra/unexpected args or kwargs
            else:
                bound_args.apply_defaults()
                # just for demonstration, use the first one that matches
                if _signature_matches(sig, bound_args):
                    return f

        raise NoMatchingOverload()

    def __call__(self, *args, **kwargs):
        try:
            f = self.best_match(*args, **kwargs)
        except NoMatchingOverload:
            pass
        else:
            return f(self.instance, *args, **kwargs)

        # no matching overload in owner class, check next in line
        super_instance = super(self.owner_cls, self.instance)
        super_call = getattr(super_instance, self.name, _MISSING)
        if super_call is not _MISSING:
            return super_call(*args, **kwargs)
        else:
            raise NoMatchingOverload()
            
def _type_hint_matches(obj, hint):
    # only works with concrete types, not things like Optional
    return hint is inspect.Parameter.empty or isinstance(obj, hint)


def _signature_matches(sig: inspect.Signature, bound_args: inspect.BoundArguments):
    # doesn't handle type hints on *args or **kwargs
    for name, arg in bound_args.arguments.items():
        param = sig.parameters[name]
        hint = param.annotation
        if not _type_hint_matches(arg, hint):
            return False
    return True

到这里已经差不多了,我们组装一下上面的代码,就可以让 Python 实现真正的重载:

import inspect

class NoMatchingOverload(Exception):
    pass

_MISSING = object()

class A(metaclass=OverloadMeta):
    @overload
    def f(self, x: int):
        print("A.f int overload", self, x)

    @overload
    def f(self, x: str):
        print("A.f str overload", self, x)

    @overload
    def f(self, x, y):
        print("A.f two arg overload", self, x, y)


class B(A):
    def normal_method(self):
        print("B.f normal method")

    @overload
    def f(self, x, y, z):
        print("B.f three arg overload", self, x, y, z)

    # works with inheritance too!


class C(B):
    @overload
    def f(self, x, y, z, t):
        print("C.f four arg overload", self, x, y, z, t)


def overloaded_class_example():
    print("OVERLOADED CLASS EXAMPLE")

    a = A()
    print(f"{a=}")
    print(f"{type(a)=}")
    print(f"{type(A)=}")
    print(f"{A.f=}")

    a.f(0)
    a.f("hello")
    # a.f(None) # Error, no matching overload
    a.f(1True)
    print(f"{A.f=}")
    print(f"{a.f=}")

    b = B()
    print(f"{b=}")
    print(f"{type(b)=}")
    print(f"{type(B)=}")
    print(f"{B.f=}")
    b.f(0)
    b.f("hello")
    b.f(1True)
    b.f(1True"hello")
    # b.f(None)  # no matching overload
    b.normal_method()

    c = C()
    c.f(1)
    c.f(123)
    c.f(1234)
    # c.f(None) # no matching overload


def main():
    overloaded_class_example()


if __name__ == "__main__":
    main()

运行结果如下:


OVERLOADED CLASS EXAMPLE
a=<__main__.A object at 0x7fbabe67d8e0>
type(a)='__main__.A'>
type(A)='__main__.OverloadMeta'>
A.f=Overload([<function A.f at 0x7fbabe679280>, <function A.f at 0x7fbabe679310>, <function A.f at 0x7fbabe6793a0>])
A.f int overload <__main__.A object at 0x7fbabe67d8e0> 0
A.f str overload <__main__.A object at 0x7fbabe67d8e0> hello
A.f two arg overload <__main__.A object at 0x7fbabe67d8e0> 1 True
A.f=Overload([<function A.f at 0x7fbabe679280>, <function A.f at 0x7fbabe679310>, <function A.f at 0x7fbabe6793a0>])
a.f=<__main__.BoundOverloadDispatcher object at 0x7fbabe67d910>
b=<__main__.B object at 0x7fbabe67d910>
type(b)='__main__.B'>
type(B)='__main__.OverloadMeta'>
B.f=Overload([<function B.f at 0x7fbabe6794c0>])
A.f int overload <__main__.B object at 0x7fbabe67d910> 0
A.f str overload <__main__.B object at 0x7fbabe67d910> hello
A.f two arg overload <__main__.B object at 0x7fbabe67d910> 1 True
B.f three arg overload <__main__.B object at 0x7fbabe67d910> 1 True hello
B.f normal method
A.f int overload <__main__.C object at 0x7fbabe67d9a0> 1
B.f three arg overload <__main__.C object at 0x7fbabe67d9a0> 1 2 3
C.f four arg overload <__main__.C object at 0x7fbabe67d9a0> 1 2 3 4

代码比较长,放在一起不利于阅读和理解,但全部代码都在正文中有展示,如果你不想自己组装,就是想要完整可一键运行的代码,可以关注公众号「Python七号」,对话框回复「重载」获取实现 Python 重载的完整代码。

浏览 80
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报