张天昀的个人博客

Python自由变量与函数闭包

2021年10月12日

1. 高阶函数

众所周知,在Python的函数中,我们可以访问在函数外部定义的变量:

>>> y = 1
>>> def add_y(x):
...     return x + y
...
>>> add_y(1)
2
>>> y = 2
>>> add_y(1)
3

该代码即等价于λ\lambda calculus中的λ x. x+y.\lambda\ x.\ x + y. 从数学(逻辑)的角度来说,变量yy不是函数的参数,因此称为自由变量。

在高阶函数(Higher-order Function)中,这种访问自由变量的特性非常有用:

>>> def print_highest(highest):
...     def do_print_highest(value):
...         if value > highest:
...             print('new highest', value)
...         return print_highest(count + 1, max(value, highest))
...     return do_print_highest
...
>>> _ = print_highest(0)(10)(20)(15)(25)(20)
new highest 10
new highest 20
new highest 25

通过闭包机制访问自由变量,高阶函数就像有了状态、有了记忆一样。

2. 函数闭包的编译

下文以CPython 3.10为Python的具体实现,来看看自由变量和函数闭包是如何实现的。

我们知道,Python是一门解释执行的编程语言,其解释执行的过程是由CPython的编译单元(compile)将Python代码转换为字节码,然后由求值单元(ceval)进行运算。

Python/symtable.c中,我们可以找到CPython创建符号表的具体代码和说明。简单来说,Python的语义分析需要两趟(two passes)来确定每一个符号的类型(作用域)。

  • 在第一趟中,CPython从抽象语法树(AST)中收集所有符号的信息,例如这个符号对应一个变量,这个符号没有在函数内定义却被使用等等。
  • 在第二趟中,CPython会根据第一趟获得的数据对所有符号进行分析。CPython采用的是栈式的符号表(即一个作用域一张符号表),在查询一个符号时可以查询当前作用域或父作用域的内容。
    • 如果符号在当前作用域没有定义,则说明是自由变量或者是隐式访问的全局符号。
    • 如果显式声明是全局符号(nonlocalglobal关键字),则在符号表中该符号必须存在,否则CPython会抛出符号错误。
    • 当一个函数中的局部变量被子作用域视为自由变量时,就代表这个变量的生命周期可能会超过定义它的函数的生命周期。此时这个变量就不再是局部变量,CPython要求定义它的函数必须用cell去打包它(类似于Java的packed type)。
    • CPython在第二趟的检查的回溯阶段(见analyze_cells函数),会将当前作用域的局部变量集合与子作用域访问的自由变量的集合取交,将其中元素的作用域改为cell。

Python/compile.c中,我们会发现CPython在_PyAST_Compile函数中调用了_PySymtable_Build函数来创建上面提到的符号表,并调用compiler_mod函数来编译当前的Python模块。在具体的函数调用中我们可以找到遍历语句、编译函数的对应代码:

static int
compiler_visit_stmt(struct compiler *c,
                    stmt_ty s)
{
    // omitted...
    switch (s->kind) {
    case FunctionDef_kind:
        return compiler_function(c, s, 0);
    // omitted...
    }
    // omitted...
}

static int
compiler_function(struct compiler *c,
                  stmt_ty s,
                  int is_async)
{
    PyCodeObject *co;
    PyObject *qualname, *docstring = NULL;
    // omitted...
    // call compiler_enter_scope
    // call compiler_add_const to add None to consts
    // call assemble to construct PyCodeObject *co
    // call compiler_exit_scope
    if (!compiler_make_closure(c, co, funcflags, qualname) {
        // omitted...
        return 0;
    }
    // omitted...
}

static int
compiler_make_closure(struct compiler *c,
                      PyCodeObject *co, /* 函数对象 */
                      Py_ssize_t flags,
                      PyObject *qualname)
{
    // omitted...
    if (co->co_nfreevars) {
        int i = co->co_nlocals + co->co_nplaincellvars;
        for (; i < co->co_nlocalsplus; ++i) {
            // omitted...
            int arg;
            if (reftype == CELL) {
                arg = compiler_lookup_arg(c->u->u_cellvars, name);
            }
            else {
                arg = compiler_lookup_arg(c->u->u_freevars, name);
            }
            ADDOP_I(c, LOAD_CLOSURE, arg);
        }
        flags |= 0x08;
        ADDOP_I(c, BUILD_TUPLE, co->co_nfreevars);
    }
    ADDOP_LOAD_CONST(c, (PyObject*)co);
    ADDOP_I(c, MAKE_FUNCTION, flags);
    return 1;
}

其中,compiler_function会调用assemble来创建Python代码对象。这个函数的具体操作就是利用先前生成的Python符号表,来填充代码对象的对应数据。对于自由变量访问,我们关心co_nfreevars,表示CodeObject中访问的自由变量的数量。

compiler_make_closure中,CPython会查找每一个cell变量和自由变量的变量编号,并添加一个对应的LOAD_CLOSURE指令。然后添加一个BUILD_TUPLE指令,将这些cell打包成一个tuple,最后通过MAKE_FUNCTIONflags |= 0x08参数提示求值模块这个函数有一个闭包需要加载。

例如,上文中的print_highest函数对应的字节码为:

>>> import dis
>>> dis.dis(print_highest)
  2           0 LOAD_CLOSURE             0 (highest)
              2 BUILD_TUPLE              1
              4 LOAD_CONST               1 (<code object do_print_highest at 0x1006d22f0, file "<stdin>", line 2>)
              6 LOAD_CONST               2 ('print_highest.<locals>.do_print_highest')
              8 MAKE_FUNCTION            8 (closure)
             10 STORE_FAST               1 (do_print_highest)

  6          12 LOAD_FAST                1 (do_print_highest)
             14 RETURN_VALUE

Disassembly of <code object do_print_highest at 0x1006d22f0, file "<stdin>", line 2>:
  3           0 LOAD_FAST                0 (value)
              2 LOAD_DEREF               0 (highest)
              4 COMPARE_OP               4 (>)
              6 POP_JUMP_IF_FALSE       18

  4           8 LOAD_GLOBAL              0 (print)
             10 LOAD_CONST               1 ('new highest')
             12 LOAD_FAST                0 (value)
             14 CALL_FUNCTION            2
             16 POP_TOP

  5     >>   18 LOAD_GLOBAL              1 (print_highest)
             20 LOAD_GLOBAL              2 (max)
             22 LOAD_FAST                0 (value)
             24 LOAD_DEREF               0 (highest)
             26 CALL_FUNCTION            2
             28 CALL_FUNCTION            1
             30 RETURN_VALUE

3. 函数闭包的执行

Python/ceval.c中,我们可以看到Python是如何执行字节码与MAKE_FUNCTION指令的:

PyObject* _Py_HOT_FUNCTION
_PyEval_EvalFrameDefault(PyThreadState *tstate,
                         InterpreterFrame *frame,
                         int throwflag)
{
    // omitted...
    switch (opcode) {
        // omitted...
        TARGET(MAKE_FUNCTION) {
            PyObject *codeobj = POP();
            PyFunctionObject *func = (PyFunctionObject *)
                PyFunction_New(codeobj, GLOBALS());
            // omitted...
            if (oparg & 0x08) {
                assert(PyTuple_CheckExact(TOP()));
                func->func_closure = POP();
            }
            // omitted...
            PUSH((PyObject *)func);
            DISPATCH();
        }
        // omitted...
    }
    // omitted...
}

结合上面的反编译字节码结果,我们可以理解Python是如何创建并处理函数闭包的:

  1. 首先,在编译阶段,CPython获取符号表信息,确定哪些变量是局部变量(varnames)、cell变量(cellvars)和自由变量(freevars)。并根据是否有自由变量来确定产生的字节码。

    • 需要注意的是:对于cellvars,Python约定由定义它的函数执行MAKE_CELL将其打包(上面的字节码中没有体现这一点,可以自行写一个其他高阶函数试试)。
     0 LOAD_CLOSURE             0 (param_free)
     2 MAKE_CELL                1 (local_free)
     4 LOAD_CLOSURE             1 (local_free)
     6 BUILD_TUPLE              2 # two freevars here
     8 LOAD_CONST               1 (<code object ...>)
    10 LOAD_CONST               2 ('qualname of func')
    12 MAKE_FUNCTION            8 (closure)
  2. 在执行阶段,对于已经打包好的cellvarsfreevars,使用LOAD_CLOSURE指令将cell读取到栈上。然后使用BUILD_TUPLE指令将栈上的n_freevars个cell打包成一个tuple。

  3. 执行LOAD_CONST将函数的代码对象、qualname读取到栈上。

  4. 执行MAKE_FUNCTION,将代码对象和全局符号相结合,生成新的代码对象。此时发现指令的参数中表明这个函数有一个闭包,将闭包从栈上取出并赋值。最后将新的代码对象压回栈上。 在执行了MAKE_FUNCTION之后,就可以通过func.__closure__来访问闭包内容了(本质上,在Python中访问__closure__会执行C里面的func->func_closure)。正如代码所表示的,__closure__是一个cell组成的tuple,可以通过func.__closure__[i].cell_contents来访问第ii个cell的内容。

  5. 执行RETURN_VALUE将函数的代码对象作为返回值返回。

  6. 用户调用返回的函数时,通过LOAD_DEREFSTORE_DEREF指令来读取、存储用cell打包的数据。由于数据是打包在cell中的,因此在调用返回的函数前修改cell中的值可以影响返回函数的行为(有点类似于C++11 Lambda中的捕获引用)。

    例如:SICP课程群中某位群友贴出来的经典(神经)深入理解Python面试题:

    >>> list = [lambda: x for x in range(10)]
    >>> list[0]()
    9  # x是一个自由变量,因此x被打包,循环中最后修改的值为x=9

4. 总结

至此,我们就从Python的底层代码的角度了解了Python是如何实现自由变量和函数闭包的。

  1. 对于不访问自由变量的函数,没有闭包。
  2. 对于访问了自由变量的函数,Python在编译字节码时分析符号并创建闭包。
  3. 对于返回函数的高阶函数,只有当运行这个函数,解释器执行LOAD_CLOSUREBUILD_TUPLEMAKE_FUNCTION指令时才真正创建闭包。

从Python的底层实现为基础,在这个角度上理解Python的高阶函数会更加清晰。