aboutsummaryrefslogtreecommitdiff
path: root/test/haketilo_test/server.py
blob: 19d4a01c50df5b9624f5fd6d3d41341a8c08a396 (about) (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# SPDX-License-Identifier: AGPL-3.0-or-later

"""
A modular "virtual network" proxy,
wrapping the classes in proxy_core.py
"""

# This file is part of Haketilo.
#
# Copyright (C) 2021 jahoti <jahoti@tilde.team>
# Copyright (C) 2021 Wojtek Kosior <koszko@koszko.org>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
#
#
# I, Wojtek Kosior, thereby promise not to sue for violation of this
# file's license. Although I request that you do not make use of this code
# in a proprietary program, I am not going to enforce this in court.

from pathlib import Path
from urllib.parse import parse_qs
from threading import Thread
import traceback

from selenium.webdriver.common.utils import free_port

from .proxy_core import ProxyRequestHandler, ThreadingHTTPServer
from .misc_constants import *
from .world_wide_library import catalog as internet

class RequestHijacker(ProxyRequestHandler):
    """
    Proxy request handler that answers requests from the `internet` catalog
    (world_wide_library.catalog) instead of contacting real servers.

    Catalog entries are keyed by request path (the part of self.path before
    any '?') and are either:
      * a tuple (status_code, headers, body_file) describing a static
        response; body_file is None for an empty body, otherwise a path
        (its .suffix is used, so a pathlib.Path is expected) whose contents
        become the body; or
      * a callable invoked as handler(method, get_params, post_params)
        returning (status_code, headers, body), body being bytes or str.
    headers must be a dict mapping str to str.
    """

    def handle_request(self, req_body):
        """
        Produce and send the response for the current request.

        req_body is the request body handed over by ProxyRequestHandler
        (falsy when the request has none). Paths absent from the catalog get
        a 404; any exception raised while building or validating the
        response is reported as a 500 whose body contains the traceback.
        """
        path, _, query = self.path.partition('?')
        try:
            if path in internet:
                info = internet[path]
                if isinstance(info, tuple):
                    status_code, headers, resp_body = \
                        self._static_response(info)
                else:
                    status_code, headers, resp_body = \
                        self._dynamic_response(info, query, req_body)

                self._validate_response(status_code, headers, resp_body)
            else:
                status_code, headers = 404, {'Content-Type': 'text/plain'}
                resp_body = b'Handler for this URL not found.'

        except Exception:
            status_code = 500
            headers     = {'Content-Type': 'text/plain'}
            resp_body   = b'Internal Error:\n' + traceback.format_exc().encode()

        # Build a new dict rather than assigning into `headers`: it may be the
        # very dict stored in the catalog (or returned by a handler), shared
        # between requests served concurrently by the threading server.
        headers = {**headers, 'Content-Length': str(len(resp_body))}
        self.send_response(status_code)
        for header, header_value in headers.items():
            self.send_header(header, header_value)

        self.end_headers()
        if resp_body:
            self.wfile.write(resp_body)

    @staticmethod
    def _static_response(info):
        """Turn a (status_code, headers, body_file) catalog tuple into a response."""
        status_code, headers, body_file = info
        resp_body = b''
        if body_file is not None:
            if 'Content-Type' not in headers:
                # Guess the type from the file extension (without the dot).
                ext = body_file.suffix[1:]
                if ext and ext in mime_types:
                    # Copy instead of mutating the catalog's own dict.
                    headers = {**headers, 'Content-Type': mime_types[ext]}

            with open(body_file, mode='rb') as f:
                resp_body = f.read()

        return status_code, headers, resp_body

    def _dynamic_response(self, handler, query, req_body):
        """Call a catalog handler function and return its (possibly encoded) response."""
        # An empty query string yields an empty dict.
        get_params = parse_qs(query)

        # Parse POST parameters; currently only supports
        # application/x-www-form-urlencoded. The body is parsed as bytes, so
        # post_params is keyed by bytes (parse_qs() mirrors its input type),
        # unlike get_params which is keyed by str.
        post_params = {}
        if req_body:
            raw_body = req_body if isinstance(req_body, (bytes, bytearray)) \
                else req_body.encode()
            post_params = parse_qs(raw_body)

        status_code, headers, resp_body = \
            handler(self.command, get_params, post_params)
        if isinstance(resp_body, str):
            resp_body = resp_body.encode()

        return status_code, headers, resp_body

    @staticmethod
    def _validate_response(status_code, headers, resp_body):
        """Raise an Exception if any part of the response is malformed."""
        # Exact type check on purpose: bool is a subclass of int.
        if type(status_code) is not int or status_code <= 0:
            raise Exception('Invalid status code %r' % (status_code,))

        for header, header_value in headers.items():
            if type(header) is not str:
                raise Exception('Invalid header key %r' % (header,))

            if type(header_value) is not str:
                raise Exception('Invalid header value %r' % (header_value,))

        # Anything else would make the Content-Length computation or the
        # body write fail after handle_request's try block, leaving the
        # client without any response.
        if not isinstance(resp_body, (bytes, bytearray)):
            raise Exception('Invalid response body %r' % (resp_body,))

def do_an_internet(certdir=default_cert_dir, port=None):
    """
    Start the proxy/server and return the server object.

    Every request is served by a RequestHijacker instance that receives
    `certdir` as its `certdir` keyword argument. When `port` is None, a port
    is chosen with selenium's free_port(). The server listens on all
    interfaces ('') and runs serve_forever() in a separate thread created
    without daemon=True, so the caller is expected to shut it down.
    """
    listen_port = free_port() if port is None else port

    # Bind certdir into the handler class, since the server instantiates
    # handlers itself and only passes the standard constructor arguments.
    class RequestHijackerWithCertdir(RequestHijacker):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, certdir=certdir, **kwargs)

    httpd = ThreadingHTTPServer(('', listen_port), RequestHijackerWithCertdir)

    server_thread = Thread(target=httpd.serve_forever)
    server_thread.start()

    return httpd