File size: 3,996 Bytes
5b832b6
 
7f79b94
5b832b6
 
 
7f79b94
 
 
 
 
 
 
5b832b6
7f79b94
 
 
 
 
 
5b832b6
7f79b94
5b832b6
 
 
 
7f79b94
5b832b6
 
 
7f79b94
5b832b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f79b94
 
 
 
 
 
 
 
 
5b832b6
 
7f79b94
 
 
 
 
 
 
 
5b832b6
 
 
7f79b94
 
5b832b6
7f79b94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b832b6
7f79b94
 
 
 
 
 
 
 
 
5b832b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import tempfile
import subprocess
from transformers import Tool
import sys
import os
import re

class PythonInterpreter(Tool):
    name = "python_interpreter_tool"
    description = "Executes Python code and returns results. Requires HF_ALLOW_CODE_EVAL = '1'"

    inputs = ["text"]
    outputs = ["text"]
    timeout = 1.0

    def __call__(self, task: str):

        if os.getenv("HF_ALLOW_CODE_EVAL", 0) != "1":
            return "Can't execute code, set code evaluation flag for PythonInterpreter"

        with tempfile.TemporaryDirectory() as temp_dir:

            code = "from safeguard import reliability_guard\nreliability_guard()\n" + task
            code_file = os.path.join(temp_dir, 'code.py')
            with open(code_file, 'w') as f:
                f.write(code)

            safeguard_file = os.path.join(temp_dir, 'safeguard.py')
            with open(safeguard_file, 'w') as f:
                f.write(reliability_guard_code)

            cmd = f"python {code_file}"
            
            try:
                output = subprocess.check_output(cmd.split(), stderr=subprocess.STDOUT, timeout=self.timeout)
                output_text = output.decode(sys.stdout.encoding).strip()
            except subprocess.TimeoutExpired:
                output_text = "Code execution timed out."
            except subprocess.CalledProcessError as e:
                output_text = e.output.decode(sys.stdout.encoding).strip()
                output_text = output_text.replace(temp_dir + "/", "/tmp/")
                output_text = fix_stacktrace_linenumber(output_text)                
        output_text = output_text.replace(temp_dir + "/", "/tmp/")
        output_text = fix_warning_linenumber(output_text)
        return output_text

reliability_guard_code = """\
def reliability_guard(maximum_memory_bytes=None):
    if maximum_memory_bytes is not None:
        import resource

        resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
        resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
        if not platform.uname().system == "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))

    import faulthandler

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None

    import os

    os.environ["OMP_NUM_THREADS"] = "1"

    os.kill = None
    os.system = None
    # os.putenv = None # breaks e.g. numpy
    os.remove = None
    os.removedirs = None
    os.rmdir = None
    os.fchdir = None
    os.setuid = None
    os.fork = None
    os.forkpty = None
    os.killpg = None
    os.rename = None
    os.renames = None
    os.truncate = None
    os.replace = None
    os.unlink = None
    os.fchmod = None
    os.fchown = None
    os.chmod = None
    os.chown = None
    os.chroot = None
    os.fchdir = None
    os.lchflags = None
    os.lchmod = None
    os.lchown = None
    os.getcwd = None
    os.chdir = None

    import shutil

    shutil.rmtree = None
    shutil.move = None
    shutil.chown = None

    import subprocess

    subprocess.Popen = None  # type: ignore

    __builtins__["help"] = None

    import sys

    sys.modules["ipdb"] = None
    sys.modules["joblib"] = None
    sys.modules["resource"] = None
    sys.modules["psutil"] = None
    sys.modules["tkinter"] = None
"""

def fix_stacktrace_linenumber(text):
    start = '  File "/tmp/code.py", line '
    lines = text.split("\n")
    fixed_lines = []
    for line in lines:
        if line.startswith(start):
            number, end = line.split(start)[1].split(",")
            new_number = int(number)-2
            line = start + str(new_number) + end
        fixed_lines.append(line)
    return "\n".join(fixed_lines)

def fix_warning_linenumber(string):
    pattern = r'code\.py:(\d+):'
    replaced_string = re.sub(pattern, lambda match: 'code.py:' + str(int(match.group(1))-2) + ':', string)
    return replaced_string