Using eval Safely in Python

Python's eval function is very powerful: it evaluates a string as Python code at runtime. For example:

>>> eval("1+2")
3
>>> eval("[x for x in range(10)]")
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

>>> import os
>>> eval("os.system('whoami')")
kyle
0

eval can only evaluate Python expressions, so it cannot run an import statement directly, whereas exec can. If you really need to import something inside eval, use __import__:

>>> eval("__import__('os').system('whoami')")

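eval accepts a single expression string and returns its value. exec accepts a whole block of code, which may contain loops, class and function definitions, and so on; it runs the block and always returns None.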
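To make the difference concrete, here is a small Python 3 sketch of my own (the variable names are only illustrative): eval hands back a value, exec returns None and leaves its results in a namespace, and eval refuses a statement such as import.

```python
# eval evaluates one expression and returns its value.
value = eval("1 + 2")

# exec runs statements; results land in the namespace, and exec returns None.
namespace = {}
returned = exec("total = 0\nfor i in range(4):\n    total += i", namespace)

# A statement such as import is not an expression, so eval rejects it.
try:
    eval("import os")
    import_error = None
except SyntaxError as exc:
    import_error = type(exc).__name__

print(value, returned, namespace["total"], import_error)
```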

For more on the differences between eval, exec and compile, see this thread: https://stackoverflow.com/questions/2220699/whats-the-difference-between-eval-exec-and-compile-in-python

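Precisely because eval is so powerful, its input may well contain malicious code and open a security hole. For example, suppose the user enters this string: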

__import__('os').system('dir')  

After eval() runs it, the files in the current directory are listed for the user. Or the input could be:

open('文件名').read()  
__import__('os').system('rm -rf /')  

We can restrict access to the built-in functions by passing a globals dict with its own __builtins__ entry, for example:

In [2]: print eval("__import__('os').remove('file')", {"__builtins__": {}})
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-2-ca308c631e67> in <module>()
----> 1 print eval("__import__('os').remove('file')", {"__builtins__": {}})

<string> in <module>()

NameError: name '__import__' is not defined
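The session above is from Python 2. A rough Python 3 equivalent, as a sketch: the harmless getcwd() call stands in for the destructive one, and abs and x are just names I picked to show that anything you pass in explicitly remains available.

```python
# With an empty __builtins__, the expression cannot even find __import__.
restricted = {"__builtins__": {}}

try:
    eval("__import__('os').getcwd()", restricted)
    blocked = False
except NameError:
    blocked = True

# Names you pass in explicitly are still usable (abs and x are illustrative).
result = eval("abs(x) + 1", {"__builtins__": {}, "abs": abs, "x": -4})

print(blocked, result)
```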

Overriding __builtins__ does not guarantee safety, though. The following code, for example, exits the interpreter:

>>> s = """
... [
...     c for c in
...     ().__class__.__bases__[0].__subclasses__()
...     if c.__name__ == "Quitter"
... ][0](0)()
... """
>>> eval(s, {'__builtins__':{}})

D:\>

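The trick is the expression ().__class__.__bases__[0].__subclasses__(): starting from an empty tuple it reaches object and lists every subclass of object, including the Quitter class behind quit() and exit(), without needing a single built-in name.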
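A harmless way to see this for yourself (a sketch; it only inspects the classes and calls none of them):

```python
# Even with no builtins, attribute access on a literal still works.
no_builtins = {"__builtins__": {}}

# Walk from an empty tuple up to its base class, which is object.
root = eval("().__class__.__bases__[0]", no_builtins)

# Every direct subclass of object is now reachable from the sandboxed code.
classes = eval("().__class__.__bases__[0].__subclasses__()", no_builtins)

print(root is object, len(classes))
```

Every class loaded into the interpreter is reachable this way, so an empty __builtins__ on its own is not a sandbox.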

An example of a CPU attack: running the eval below keeps your CPU at 100% load for a long time and hangs the whole program.

eval("2**9999999999**9999999")

There is also a memory attack. In Python 2, running

eval('(1,' * 100 + ')' * 100)

raises a MemoryError:

In [3]: eval('(1,' * 100 + ')' * 100)
s_push: parser stack overflow
---------------------------------------------------------------------------
MemoryError Traceback (most recent call last)
in ()
----> 1 eval('(1,' * 100 + ')' * 100)

MemoryError:

In Python 2, running the eval below crashes with Segmentation fault (core dumped):

print eval('(lambda i: [i for i in ((i, 1) for j in range(1000000))][-1])(1)')

So how can we guard eval and validate its argument? I found two approaches online.

The first approach comes from https://opensourcehacker.com/2014/10/29/safe-evaluation-of-math-expressions-in-pure-python/

We use the Python compile() function to prepare the Python expression to be the bytecode for evaling. Then, we actually don’t eval(). Instead we use custom opcode handlers and evaluate opcodes in a loop one after another. There cannot be a sandbox escape, because opcodes having such functionality are not implemented. Because compile() does the job of generating the microcode, we are saved from the headache of writing a custom parser. Python dis module helps us minimize the code needed for a stack-based virtual machine.

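In other words, instead of calling eval directly, first use compile to get the bytecode, then write your own handlers that interpret the bytecode one instruction at a time.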
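To get a feel for what such an interpreter has to handle, here is a minimal sketch of my own that only compiles an expression and lists its opcodes with the standard dis module. The helper names opcode_names and has_call are hypothetical, and the exact opcode names differ between CPython versions, so treat the printed lists as version-specific.

```python
import dis


def opcode_names(expr):
    """Compile expr in eval mode (without running it) and list its opcodes."""
    code = compile(expr, "<expr>", "eval")
    return [ins.opname for ins in dis.get_instructions(code)]


def has_call(expr):
    """Report whether the compiled expression contains a call opcode."""
    return any(name.startswith("CALL") for name in opcode_names(expr))


for expr in ("a + b", "__import__('os')"):
    print(expr, opcode_names(expr), has_call(expr))
```

A custom evaluator would implement handlers only for the opcodes it wants to support and reject everything else.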

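As for the CPU attack, run the evaluation in a child process with a time limit and kill the process if it times out. Because of Python's GIL, a thread cannot be used to interrupt the CPU-bound work once it times out, so a separate process is the only option.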

We do this using Python’s multiprocessing module. The calculation is run in a separate child process. This process is terminated if it doesn’t finish timely. In our case, we’ll give 100 milliseconds for the expression calculations to finish.

As a side note, threads didn’t work here because it turned out compile() does not release the GIL, and thus thread.start() never returns if thread.run() contains a complex compile(): the original thread does not get the GIL back.
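The same kill-on-timeout idea can be sketched with the standard subprocess module instead of multiprocessing (my substitution, not what the gist does). The function name eval_with_timeout and the CHILD_CODE one-liner are hypothetical, and the child still calls plain eval, so this only bounds the running time and does nothing about the other attacks.

```python
import ast
import subprocess
import sys

# Code run by the child interpreter: evaluate argv[1] and print its repr.
# Assumption: plain eval is used here, so only the time limit is enforced.
CHILD_CODE = "import sys; print(repr(eval(sys.argv[1])))"


def eval_with_timeout(expr, seconds):
    """Evaluate expr in a fresh interpreter, killing it after `seconds`."""
    try:
        proc = subprocess.run(
            [sys.executable, "-c", CHILD_CODE, expr],
            capture_output=True,
            text=True,
            timeout=seconds,  # subprocess.run kills the child on expiry
        )
    except subprocess.TimeoutExpired:
        raise TimeoutError("evaluation exceeded {} seconds".format(seconds))
    if proc.returncode != 0:
        raise ValueError(proc.stderr.strip())
    # Only literal results (numbers, strings, lists, ...) can be passed back.
    return ast.literal_eval(proc.stdout.strip())
```

For example, eval_with_timeout("1 + 2", 5) returns 3, while an expression that never finishes raises TimeoutError once the limit is reached.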

The full implementation is at https://gist.github.com/miohtama/34a83d870a14aa7e580d, and the code is also included at the end of this post.

The second approach comes from the Evalidate library; see http://evalidate.readthedocs.io/en/latest/

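I personally prefer this approach, and the implementation is short. The idea is to parse the code into a syntax tree with the ast module, then check each node of the tree against a whitelist, so we control exactly which syntax is allowed.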

For example, to deal with eval("2**9999999999**9999999") we can disallow exponentiation, whose AST node is named Pow. The node names are documented at https://greentreesnakes.readthedocs.io/en/latest/nodes.html
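Here is a stripped-down sketch of the whitelist idea for Python 3.8+, where literals parse as ast.Constant. It is my own illustration, not Evalidate's code; the names ALLOWED, check and safe_arith are hypothetical, and the whitelist covers only basic arithmetic.

```python
import ast

# Node types this sketch accepts; Pow is deliberately left out.
ALLOWED = (ast.Expression, ast.BinOp, ast.UnaryOp, ast.Constant,
           ast.Add, ast.Sub, ast.Mult, ast.Div, ast.USub)


def check(expr):
    """Parse expr and raise ValueError on any node outside the whitelist."""
    tree = ast.parse(expr, mode="eval")
    for node in ast.walk(tree):
        if not isinstance(node, ALLOWED):
            raise ValueError("{} is not allowed".format(type(node).__name__))
        # Accept numeric literals only, so strings cannot be multiplied up.
        if isinstance(node, ast.Constant) and not isinstance(node.value, (int, float)):
            raise ValueError("only numeric constants are allowed")
    return tree


def safe_arith(expr):
    """Evaluate a validated arithmetic expression with no builtins."""
    tree = check(expr)
    return eval(compile(tree, "<expr>", "eval"), {"__builtins__": {}})
```

With this whitelist safe_arith("1 + 2 * 3") evaluates normally, while "2 ** 10" and "__import__('os')" are rejected before anything is compiled or run.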

The complete code is as follows:


#!/usr/bin/python

# from http://evalidate.readthedocs.io/en/latest/
"""Safe user-supplied python expression evaluation."""

import ast
import sys

version = '0.6'


class SafeAST(ast.NodeVisitor):

    """AST-tree walker class."""

    allowed = {}

    def __init__(self, safenodes=None, addnodes=None):
        """create whitelist of allowed operations."""
        if safenodes is not None:
            self.allowed = safenodes
        else:
            '''Nodes doc: https://greentreesnakes.readthedocs.io/en/latest/nodes.html'''
            # 123, 'asdf' (literals parse as 'Constant' on Python 3.8+)
            values = ['Num', 'Str', 'Constant']
            # any expression
            expression = ['Expression']
            # == ...
            compare = ['Compare', 'Eq', 'NotEq', 'Gt', 'GtE', 'Lt', 'LtE']
            # variable name
            variables = ['Name', 'Load']
            binop = ['BinOp']
            arithmetics = ['Add', 'Sub', 'Div']
            subscript = ['Subscript', 'Index']  # person['name']
            boolop = ['BoolOp', 'And', 'Or', 'UnaryOp', 'Not']  # True and True
            inop = ["In"]  # "aaa" in i['list']
            ifop = ["IfExp"] # for if expressions, like: expr1 if expr2 else expr3
            nameconst = ["NameConstant"] # for True and False constants

            self.allowed = expression + values + compare + variables + binop + \
                arithmetics + subscript + boolop + inop + ifop + nameconst

        if addnodes is not None:
            self.allowed = self.allowed + addnodes

    def generic_visit(self, node):
        """Check node, raise exception if node is not in whitelist."""
        if type(node).__name__ in self.allowed:
            ast.NodeVisitor.generic_visit(self, node)
        else:
            raise ValueError(
                "Operation type {optype} is not allowed".format(
                    optype=type(node).__name__))


def evalidate(expression, safenodes=None, addnodes=None):
    """Validate expression.

    return node if it passes our checks
    or pass exception from SafeAST visit.
    """
    node = ast.parse(expression, '<usercode>', 'eval')

    v = SafeAST(safenodes, addnodes)
    v.visit(node)
    return node


def safeeval(src, context={}, safenodes=None, addnodes=None):
    """C-style simplified wrapper, eval() replacement."""
    try:
        node = evalidate(src, safenodes, addnodes)
    except Exception as e:
        return (False, "Validation error: "+e.__str__())

    try:
        code = compile(node, '<usercode>', 'eval')
    except Exception as e:
        return (False, "Compile error: "+e.__str__())

    try:
        wcontext = context.copy()
        result = eval(code, wcontext)
    except Exception as e:
        et, ev, erb = sys.exc_info()
        return False, "Runtime error ({}): {}".format(type(e).__name__, ev)

    return (True, result)


if __name__ == '__main__':

    books = [
        {
            'title': 'The Sirens of Titan',
            'author': 'Kurt Vonnegut',
            'stock': 10,
            'price': 9.71
        },
        {
            'title': 'Cat\'s Cradle',
            'author': 'Kurt Vonnegut',
            'stock': 2,
            'price': 4.23
        },
        {
            'title': 'Chapaev i Pustota',
            'author': 'Victor Pelevin',
            'stock': 0,
            'price': 21.33
        },
        {
            'title': 'Gone Girl',
            'author': 'Gillian Flynn',
            'stock': 5,
            'price': 8.97
        },
    ]

    # src = 'stock>= (5 if price<9 else 0)'
    src = 'stock>5 or price>10'

    for book in books:
        success, result = safeeval(src, book)
        if success:
            if result:
                print(book)
        else:
            print("ERR: ", result)

    print(safeeval("os.system('whoami')"))
    print(safeeval('import os'))
    print(safeeval("__import__('os').system('dir')  "))
    print(safeeval('a + (3 if 3>b else b) / c', {'a':5, 'b':1, 'c':2}))
    print(safeeval('a + b**3', {'a':5, 'b':1, 'c':2}))

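The code for the first approach is as follows. Note that its disassemble function decodes the variable-length bytecode layout used before Python 3.6, so it needs adapting on newer interpreters.

"""
    The original author: Alexer / #python.fi
"""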

import opcode
import dis
import sys
import multiprocessing
import time

# Python 3 required
assert sys.version_info[0] == 3, "No country for old snakes"


class UnknownSymbol(Exception):
    """ There was a function or constant in the expression we don't support. """


class BadValue(Exception):
    """ The user tried to input dangerously big value. """

    MAX_ALLOWED_VALUE = 2**63


class BadCompilingInput(Exception):
    """ The user tried to input something which might cause compiler to slow down. """


class TimeoutException(Exception):
    """ It took too long to compile and execute. """


class RunnableProcessing(multiprocessing.Process):
    """ Run a function in a child process.
    Pass back any exception received.
    """
    def __init__(self, func, *args, **kwargs):
        self.queue = multiprocessing.Queue(maxsize=1)
        args = (func,) + args
        multiprocessing.Process.__init__(self, target=self.run_func, args=args, kwargs=kwargs)

    def run_func(self, func, *args, **kwargs):
        try:
            result = func(*args, **kwargs)
            self.queue.put((True, result))
        except Exception as e:
            self.queue.put((False, e))

    def done(self):
        return self.queue.full()

    def result(self):
        return self.queue.get()


def timeout(seconds, force_kill=True):
    """ Timeout decorator using Python multiprocessing.
    Courtesy of http://code.activestate.com/recipes/577853-timeout-decorator-with-multiprocessing/
    """
    def wrapper(function):
        def inner(*args, **kwargs):
            now = time.time()
            proc = RunnableProcessing(function, *args, **kwargs)
            proc.start()
            proc.join(seconds)
            if proc.is_alive():
                if force_kill:
                    proc.terminate()
                runtime = time.time() - now
                raise TimeoutException('timed out after {0} seconds'.format(runtime))
            assert proc.done()
            success, result = proc.result()
            if success:
                return result
            else:
                raise result
        return inner
    return wrapper


def disassemble(co):
    """ Loop through Python bytecode and match instructions with our internal opcodes.
    :param co: Python code object
    """
    code = co.co_code
    n = len(code)
    i = 0
    extended_arg = 0
    result = []
    while i < n:
        op = code[i]
        curi = i
        i = i+1
        if op >= dis.HAVE_ARGUMENT:
            # Python 2
            # oparg = ord(code[i]) + ord(code[i+1])*256 + extended_arg
            oparg = code[i] + code[i+1] * 256 + extended_arg
            extended_arg = 0
            i = i+2
            if op == dis.EXTENDED_ARG:
                # Python 2
                #extended_arg = oparg*65536L
                extended_arg = oparg*65536
        else:
            oparg = None

        # print(opcode.opname[op])

        opv = globals()[opcode.opname[op].replace('+', '_')](co, curi, i, op, oparg)

        result.append(opv)

    return result

# For the opcodes see dis.py
# (Copy-paste)
# https://docs.python.org/2/library/dis.html

class Opcode:
    """ Base class for our internal opcodes. """
    args = 0
    pops = 0
    pushes = 0
    def __init__(self, co, i, nexti, op, oparg):
        self.co = co
        self.i = i
        self.nexti = nexti
        self.op = op
        self.oparg = oparg

    def get_pops(self):
        return self.pops

    def get_pushes(self):
        return self.pushes

    def touch_value(self, stack, frame):
        assert self.pushes == 0
        for i in range(self.pops):
            stack.pop()


class OpcodeArg(Opcode):
    args = 1


class OpcodeConst(OpcodeArg):
    def get_arg(self):
        return self.co.co_consts[self.oparg]


class OpcodeName(OpcodeArg):
    def get_arg(self):
        return self.co.co_names[self.oparg]


class POP_TOP(Opcode):
    """Removes the top-of-stack (TOS) item."""
    pops = 1
    def touch_value(self, stack, frame):
        stack.pop()


class DUP_TOP(Opcode):
    """Duplicates the reference on top of the stack."""
    # XXX: +-1
    pops = 1
    pushes = 2
    def touch_value(self, stack, frame):
        stack[-1:] = 2 * stack[-1:]


class ROT_TWO(Opcode):
    """Swaps the two top-most stack items."""
    pops = 2
    pushes = 2
    def touch_value(self, stack, frame):
        stack[-2:] = stack[-2:][::-1]


class ROT_THREE(Opcode):
    """Lifts second and third stack item one position up, moves top down to position three."""
    pops = 3
    pushes = 3
    direct = True
    def touch_value(self, stack, frame):
        v3, v2, v1 = stack[-3:]
        stack[-3:] = [v1, v3, v2]


class ROT_FOUR(Opcode):
    """Lifts second, third and fourth stack item one position up, moves top down to position four."""
    pops = 4
    pushes = 4
    direct = True
    def touch_value(self, stack, frame):
        # Four items are rotated, so slice the top four (not three) entries
        v4, v3, v2, v1 = stack[-4:]
        stack[-4:] = [v1, v4, v3, v2]


class UNARY(Opcode):
    """Unary Operations take the top of the stack, apply the operation, and push the result back on the stack."""
    pops = 1
    pushes = 1


class UNARY_POSITIVE(UNARY):
    """Implements TOS = +TOS."""
    def touch_value(self, stack, frame):
        stack[-1] = +stack[-1]


class UNARY_NEGATIVE(UNARY):
    """Implements TOS = -TOS."""
    def touch_value(self, stack, frame):
        stack[-1] = -stack[-1]


class BINARY(Opcode):
    """Binary operations remove the top of the stack (TOS) and the second top-most stack item (TOS1) from the stack. 
They perform the operation, and put the result back on the stack."""
    pops = 2
    pushes = 1


class BINARY_POWER(BINARY):
    """Implements TOS = TOS1 ** TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        print(TOS1, TOS)
        if abs(TOS1) > BadValue.MAX_ALLOWED_VALUE or abs(TOS) > BadValue.MAX_ALLOWED_VALUE:
            raise BadValue("The value for exponent was too big")

        stack[-2:] = [TOS1 ** TOS]


class BINARY_MULTIPLY(BINARY):
    """Implements TOS = TOS1 * TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 * TOS]


class BINARY_DIVIDE(BINARY):
    """Implements TOS = TOS1 / TOS when from __future__ import division is not in effect."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 / TOS]


class BINARY_MODULO(BINARY):
    """Implements TOS = TOS1 % TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 % TOS]


class BINARY_ADD(BINARY):
    """Implements TOS = TOS1 + TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 + TOS]


class BINARY_SUBTRACT(BINARY):
    """Implements TOS = TOS1 - TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 - TOS]


class BINARY_FLOOR_DIVIDE(BINARY):
    """Implements TOS = TOS1 // TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 // TOS]


class BINARY_TRUE_DIVIDE(BINARY):
    """Implements TOS = TOS1 / TOS when from __future__ import division is in effect."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 / TOS]


class BINARY_LSHIFT(BINARY):
    """Implements TOS = TOS1 << TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 << TOS]


class BINARY_RSHIFT(BINARY):
    """Implements TOS = TOS1 >> TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 >> TOS]


class BINARY_AND(BINARY):
    """Implements TOS = TOS1 & TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 & TOS]


class BINARY_XOR(BINARY):
    """Implements TOS = TOS1 ^ TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 ^ TOS]


class BINARY_OR(BINARY):
    """Implements TOS = TOS1 | TOS."""
    def touch_value(self, stack, frame):
        TOS1, TOS = stack[-2:]
        stack[-2:] = [TOS1 | TOS]


class RETURN_VALUE(Opcode):
    """Returns with TOS to the caller of the function."""
    pops = 1
    final = True
    def touch_value(self, stack, frame):
        value = stack.pop()
        return value


class LOAD_CONST(OpcodeConst):
    """Pushes co_consts[consti] onto the stack.""" # consti
    pushes = 1
    def touch_value(self, stack, frame):
        # XXX moo: Validate type
        value = self.get_arg()
        assert isinstance(value, (int, float))
        stack.append(value)


class LOAD_NAME(OpcodeName):
    """Pushes the value associated with co_names[namei] onto the stack.""" # namei
    pushes = 1
    def touch_value(self, stack, frame):
        # XXX moo: Get name from dict of valid variables/functions
        name = self.get_arg()
        if name not in frame:
            raise UnknownSymbol("Does not know symbol {}".format(name))
        stack.append(frame[name])


class CALL_FUNCTION(OpcodeArg):
    """Calls a function. The low byte of argc indicates the number of positional parameters, the high byte the number of keyword parameters. 
On the stack, the opcode finds the keyword parameters first. For each keyword argument,
 the value is on top of the key. Below the keyword parameters, the positional parameters are on the stack, 
with the right-most parameter on top. Below the parameters, the function object to call is on the stack. Pops all function arguments,
 and the function itself off the stack, and pushes the return value.""" # argc
    pops = None
    pushes = 1

    def get_pops(self):
        args = self.oparg & 0xff
        kwargs = (self.oparg >> 8) & 0xff
        return 1 + args + 2 * kwargs

    def touch_value(self, stack, frame):
        argc = self.oparg & 0xff
        kwargc = (self.oparg >> 8) & 0xff
        assert kwargc == 0
        if argc > 0:
            args = stack[-argc:]
            stack[:] = stack[:-argc]
        else:
            args = []
        func = stack.pop()

        assert func in frame.values(), "Uh-oh somebody injected bad function. This does not happen."

        result = func(*args)
        stack.append(result)


def check_for_pow(expr):
    """ Python evaluates the power operator at compile time if it is applied to constants.
    You can do CPU / memory burning attack with ``2**999999999999999999999**9999999999999``.
    We mainly care about memory now, as we catch timeouts in any case.
    We just disable pow and do not care about it.
    """
    if "**" in expr:
        raise BadCompilingInput("Power operation is not allowed")


def _safe_eval(expr, functions_and_constants={}, check_compiling_input=True):
    """ Evaluate a Pythonic math expression and return the output as a string.
    The expr is limited to 1024 characters / 1024 operations
    to prevent CPU burning or memory stealing.
    :param functions_and_constants: Supplied "built-in" data for evaluation
    """

    # Some safety checks
    assert len(expr) < 1024

    # Check for potential bad compiler input
    if check_compiling_input:
        check_for_pow(expr)

    # Compile Python source code to Python code for eval()
    code = compile(expr, '', 'eval')

    # Dissect bytecode back to Python opcodes
    ops = disassemble(code)
    assert len(ops) < 1024

    stack = []
    for op in ops:
        value = op.touch_value(stack, functions_and_constants)

    return value


@timeout(0.1)
def safe_eval_timeout(expr, functions_and_constants={}, check_compiling_input=True):
    """ Hardened compile + eval for long running maths.
    Mitigate against CPU burning attacks.
    If some nasty user figures out a way around our limitations to make really really slow calculations.
    """
    return _safe_eval(expr, functions_and_constants, check_compiling_input)


if __name__ == "__main__":

    # Run some self testing

    def test_eval(expected_result, *args):
        result = safe_eval_timeout(*args)
        if result != expected_result:
            raise AssertionError("Got: {} expected: {}".format(result, expected_result))

    test_eval(2, "1+1")
    test_eval(2, "1 + 1")
    test_eval(3, "a + b", dict(a=1, b=2))
    test_eval(2, "max(1, 2)", dict(max=max))
    test_eval(2, "max(a, b)", dict(a=1, b=2, max=max))
    test_eval(3, "max(a, c, b)", dict(a=1, b=2, c=3, max=max))
    test_eval(3, "max(a, max(c, b))", dict(a=1, b=2, c=3, max=max))
    test_eval("2", "str(1 + 1)", dict(str=str))
    test_eval(2.5, "(a + b) / c", dict(a=4, b=1, c=2))

    try:
        test_eval(None, "max(1, 0)")
        raise AssertionError("Should not be reached")
    except UnknownSymbol:
        pass

    # CPU burning
    try:
        test_eval(None, "2**999999999999999999999**9999999999")
        raise AssertionError("Should not be reached")
    except BadCompilingInput:
        pass

    # CPU burning, check that our timeout works
    try:
        safe_eval_timeout("2**999999999999999999999**9999999999", check_compiling_input=False)
        raise AssertionError("Should not be reached")
    except TimeoutException:
        pass

    try:
        test_eval(None, "1 / 0")
        raise AssertionError("Should not be reached")
    except ZeroDivisionError:
        pass

    try:
        test_eval(None, "(((((((((((((((()")
        raise AssertionError("Should not be reached")
    except SyntaxError:
        #    for i in range(0, 100):
        #      ^
        # SyntaxError: invalid syntax
        pass

    try:
        test_eval(None, "")
        raise AssertionError("Should not be reached")
    except SyntaxError:
        # SyntaxError: unexpected EOF while parsing
        pass

    # compile() should not allow multiline stuff
    # http://stackoverflow.com/q/12698028/315168
    try:
        test_eval(None, "for i in range(0, 100):\n    pass", dict(i=-1))
        raise AssertionError("Should not be reached")
    except SyntaxError:
        #    for i in range(0, 100):
        #      ^
        # SyntaxError: invalid syntax
        pass

    # No functions allowed
    try:
        test_eval(None, "lamdba x: x+1")
        raise AssertionError("Should not be reached")
    except SyntaxError:
        # SyntaxError: unexpected EOF while parsing
        pass

References:

  • https://github.com/aosabook/500lines/blob/master/template-engine/code/templite.py
  • https://github.com/greysign/pysec/blob/master/safeeval.py
  • http://rinige.com/index.php/archives/571/
  • http://evalidate.readthedocs.io/en/latest/
  • https://opensourcehacker.com/2014/10/29/safe-evaluation-of-math-expressions-in-pure-python/
  • https://gist.github.com/miohtama/34a83d870a14aa7e580d
  • https://greentreesnakes.readthedocs.io/en/latest/nodes.html