aboutsummaryrefslogtreecommitdiff
path: root/src/hourly.py
blob: ca9fc2f8a45caac9b3cb2fb92adedcfede79e09d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#!/bin/python3

import subprocess
from os import path, waitpid, unlink, WEXITSTATUS, chown
from time import gmtime, strftime, sleep
import re
from pathlib import Path
from pwd import getpwnam

import psycopg2

# our own module used by several scripts in the project
from ztdnslib import start_db_connection, \
    get_default_host_address, get_ztdns_config, log, set_loghour, logfile

# Wrapper script started once per vpn connection by do_hourly_work();
# see vpn_wrapper.sh for the meaning of its arguments.
wrapper = '/var/lib/0tdns/vpn_wrapper.sh'
# Appended to the wrapper's argument list, together with the hour and the
# vpn id, as the command it should run (command_in_namespace in
# do_hourly_work()).
perform_queries = '/var/lib/0tdns/perform_queries.py'
# Created exclusively by lock_on_file() and removed by unlock_on_file(); while
# it exists, a new run only logs a failure instead of doing the hourly work.
lockfile = '/var/lib/0tdns/lockfile'

def sync_ovpn_config(cursor, vpn_id, config_path, config_hash):
    """Fetch a vpn's OpenVPN config from the database and write it to disk.

    Selects ``ovpn_config`` from ``user_side_vpn`` for the row matching both
    ``vpn_id`` and ``config_hash`` and writes its bytes to ``config_path``.

    The contents are first written under a temporary name next to
    ``config_path`` and then renamed into place. An interrupted write
    therefore never leaves a truncated config at ``config_path``. This
    matters because configs are only re-synced when that path does not
    exist, so a partial file would never be repaired.

    :raises LookupError: if no row matches ``vpn_id`` and ``config_hash``.
    """
    import os  # the module header only imports individual names from os

    cursor.execute('''
    select ovpn_config
    from user_side_vpn
    where id = %s and ovpn_config_sha256 = %s
    ''', (vpn_id, config_hash))

    row = cursor.fetchone()
    if row is None:
        raise LookupError('no ovpn config found for vpn {} with hash {}'
                          .format(vpn_id, config_hash))
    (config_contents,) = row

    tmp_path = config_path + '.part'
    try:
        with open(tmp_path, "wb") as config_file:
            # presumably psycopg2 returns the bytea column as a memoryview
            config_file.write(config_contents.tobytes())
        # atomic on POSIX: readers see either no file or the complete one
        os.replace(tmp_path, config_path)
    except BaseException:
        # best-effort removal of the partial file; re-raise the real error
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

def get_vpn_connections(cursor, hour):
    """Return the vpns that have queries to perform.

    Each result row is a ``(vpn_id, ovpn_config_sha256)`` pair, one per
    distinct vpn referenced from ``user_side_queries``. The rows are
    returned exactly as produced by ``cursor.fetchall()``.

    ``hour`` is accepted for the caller's convenience but is not currently
    used by the query.
    """
    query = '''
    SELECT DISTINCT v.id, v.ovpn_config_sha256
    FROM user_side_queries AS q JOIN user_side_vpn AS v
    ON v.id = q.vpn_id;
    '''
    cursor.execute(query)
    rows = cursor.fetchall()
    return rows

def lock_on_file():
    """Try to take the run lock by exclusively creating ``lockfile``.

    Returns True when this call created the file, and False when it
    already existed. Other OS errors propagate.
    """
    try:
        handle = open(lockfile, 'x')
    except FileExistsError:
        return False
    handle.close()
    return True

def unlock_on_file():
    """Release the run lock by deleting ``lockfile``.

    Returns True when the file was removed, and False when it had already
    disappeared. Other OS errors propagate.
    """
    try:
        unlink(lockfile)
    except FileNotFoundError:
        return False
    else:
        return True

# Matches an IPv4 address range written as "<first> - <last>", e.g.
# '10.25.25.0 - 10.25.25.59'. Group 1 is the first address and group 2 the
# last one. Only the shape is checked here (any run of digits is accepted per
# octet); octet values are range-checked later by ip_address_to_number().
# It is used with .match(), so text after the last address is not rejected.
address_range_regex = re.compile(r'''
([\d]+\.[\d]+\.[\d]+\.[\d]+) # first IPv4 address in the range

[\s]*-[\s]*                  # dash (with optional whitespace around)

([\d]+\.[\d]+\.[\d]+\.[\d]+) # last IPv4 address in the range
''', re.VERBOSE)

# Matches a dotted-quad IPv4 address, capturing its four octets as groups
# 1-4 (as digit strings; their numeric range is not checked here).
address_regex = re.compile(r'([\d]+)\.([\d]+)\.([\d]+)\.([\d]+)')

def ip_address_to_number(address):
    """Convert a dotted-quad IPv4 address string to its 32-bit integer value.

    Returns None if ``address`` is not exactly four dot-separated decimal
    octets, or if any octet is greater than 255.

    >>> ip_address_to_number('10.0.0.1')
    167772161
    """
    # fullmatch: reject trailing text such as '1.2.3.4.5' or '1.2.3.4x'
    match = address_regex.fullmatch(address)
    if not match:
        return None
    number = 0
    for octet in match.groups():
        value = int(octet)
        if value > 255:  # an octet holds 0..255
            return None
        number = number * 256 + value
    return number

def number_to_ip_address(number):
    """Format the low 32 bits of an integer as a dotted-quad IPv4 string.

    The most significant octet comes first, e.g. 167772161 -> '10.0.0.1'.
    Bits above the lowest 32 are ignored.
    """
    octets = []
    for _ in range(4):
        number, remainder = divmod(number, 256)
        octets.append(remainder)
    # octets were collected least significant first
    return "{}.{}.{}.{}".format(*reversed(octets))

# this functions accepts list of IPv4 address ranges like:
#     ['10.25.25.0 - 10.25.25.59', '10.25.25.120 - 10.25.25.135']
# and returns a set of /30 subnetworks; each subnetwork is represented
# by a tuple of 2 usable addresses within that subnetwork.
# E.g. for subnetwork 10.25.25.16/30 it would be ('10.25.25.17', '10.25.25.18');
# Addressess ending with .16 (subnet address)
# and .19 (broadcast in the subnet) are considered unusable in this case.
# The returned set will contain up to count elements.
def get_available_subnetworks(count, address_ranges):
    available_subnetworks = set()

    for address_range in address_ranges:
        match = address_range_regex.match(address_range)
        ok_flag = True

        if not match:
            ok_flag = False

        if ok_flag:
            start_addr_number = ip_address_to_number(match.groups()[0])
            end_addr_number = ip_address_to_number(match.groups()[1])
            if not start_addr_number or not end_addr_number:
                ok_flag = False

        if ok_flag:
            # round so that start_addr is first ip address in a /30 network
            # and end_addr is last ip address in a /30 network
            while start_addr_number % 4 != 0:
                start_addr_number += 1
            while end_addr_number % 4 != 3:
                end_addr_number -= 1

            if start_addr_number >= end_addr_number:
                log("address range '{}' doesn't contain any /30 subnetworks"\
                    .format(address_range))
            else:
                while len(available_subnetworks) < count and \
                      start_addr_number < end_addr_number:
                    usable_addr1 = number_to_ip_address(start_addr_number + 1)
                    usable_addr2 = number_to_ip_address(start_addr_number + 2)
                    available_subnetworks.add((usable_addr1, usable_addr2))
                    start_addr_number += 4
        else:
            log("'{}' is not a valid address range".format(address_range))

    return available_subnetworks

def do_hourly_work(hour):
    """Run the per-vpn query jobs for one hour.

    If the 0tdns config does not have ``enabled`` set to 'yes' or True,
    this only logs and returns. Otherwise it opens a database connection
    and delegates the work to _run_vpn_wrappers(). The cursor and
    connection are always closed afterwards, including when an exception
    propagates.

    :param hour: the hour being processed, as a timestamp string. It is
        forwarded to the wrappers and used as the date of any failure rows
        inserted into user_side_responses.
    """
    ztdns_config = get_ztdns_config()
    if ztdns_config['enabled'] not in ['yes', True]:
        log('0tdns not enabled in the config - exiting')
        return

    connection = start_db_connection(ztdns_config)
    try:
        cursor = connection.cursor()
        try:
            _run_vpn_wrappers(cursor, hour, ztdns_config)
        finally:
            cursor.close()
    finally:
        connection.close()


def _run_vpn_wrappers(cursor, hour, ztdns_config):
    """Sync missing .ovpn configs, then run one wrapper process per vpn.

    At most ``parallel_vpns`` wrappers run at once, further limited by the
    number of available /30 subnets; each wrapper gets its own subnet for
    its veth device. The caller owns ``cursor`` and closes it.
    """
    vpns = get_vpn_connections(cursor, hour)

    # if not specified in the config, all vpns are handled
    handled_vpns = ztdns_config.get('handled_vpns')
    if handled_vpns:
        log('Only handling vpns of ids {}'.format(handled_vpns))
        vpns = [vpn for vpn in vpns if vpn[0] in handled_vpns]

    parallel_vpns = ztdns_config['parallel_vpns']  # we need this many subnets
    subnets = get_available_subnetworks(parallel_vpns,
                                        ztdns_config['private_addresses'])

    if not subnets:
        log("couldn't get ANY /30 subnet of private"
            " addresses from the 0tdns config file - exiting")
        return

    if len(subnets) < parallel_vpns:
        log('configuration allows running {0} parallel vpn connections, but'
            ' provided private ip addresses give only {1} /30 subnets, which'
            ' limits parallel connections to {1}'
            .format(parallel_vpns, len(subnets)))
        parallel_vpns = len(subnets)

    for vpn_id, config_hash in vpns:
        config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
        if not path.isfile(config_path):
            log('Syncing config for vpn {} with hash {}'
                .format(vpn_id, config_hash))
            sync_ovpn_config(cursor, vpn_id, config_path, config_hash)

    # map of each wrapper pid to a tuple of: id of the vpn it connects to,
    # subnet (tuple of 2 addresses) used for its veth device, and the Popen
    # object (kept only as a reference, see below)
    pids_wrappers = {}

    def wait_for_wrapper_process():
        """Reap one finished wrapper, record a failure if any, free its subnet."""
        while True:
            pid, exit_status = waitpid(0, 0)
            # make sure it's one of our wrapper processes
            vpn_id, subnet, _ = pids_wrappers.get(pid, (None, None, None))
            if subnet:
                break

        # NOTE(review): WEXITSTATUS is only meaningful when WIFEXITED is true;
        # a wrapper killed by a signal likely yields 0 here and is then
        # treated as successful - consider checking WIFSIGNALED as well.
        exit_status = WEXITSTATUS(exit_status)  # read man waitpid if wondering
        if exit_status != 0:
            if exit_status == 2:
                # this means our perform_queries.py crashed... not good
                log('performing queries through vpn {} failed'.format(vpn_id))
                result_info = 'internal failure: process_crash'
            else:
                # vpn server is probably not responding
                log('connection to vpn {} failed'.format(vpn_id))
                result_info = 'internal failure: vpn_connection_failure'

            try:
                cursor.execute('''
                INSERT INTO user_side_responses
                    (date, result, dns_id, service_id, vpn_id)
                (SELECT TIMESTAMP WITH TIME ZONE %s, %s,
                        dns_id, service_id, vpn_id
                FROM user_side_queries
                WHERE vpn_id = %s);
                ''', (hour, result_info, vpn_id))
            except psycopg2.IntegrityError:
                # NOTE(review): nothing here commits, and unless
                # start_db_connection enables autocommit this error aborts
                # the transaction for later statements - confirm.
                log('results already exist for vpn {}'.format(vpn_id))

        pids_wrappers.pop(pid)
        subnets.add(subnet)

    for vpn_id, config_hash in vpns:
        if len(pids_wrappers) == parallel_vpns:
            wait_for_wrapper_process()

        config_path = "/var/lib/0tdns/{}.ovpn".format(config_hash)
        physical_ip = get_default_host_address(ztdns_config['host'])
        subnet = subnets.pop()
        veth_addr1, veth_addr2 = subnet
        route_through_veth = ztdns_config['host'] + "/32"
        command_in_namespace = [perform_queries, hour, str(vpn_id)]
        log('Running connection for vpn {}'.format(vpn_id))

        # see vpn_wrapper.sh for an explanation of its arguments
        p = subprocess.Popen([wrapper, config_path, physical_ip, veth_addr1,
                              veth_addr2, route_through_veth, str(vpn_id)] +
                             command_in_namespace)

        # we're not actually using the subprocess object anywhere, but we
        # put it in the dict regardless to keep a reference to it - otherwise
        # python would reap the child for us and waitpid(0, 0) would raise
        # '[Errno 10] No child processes' :c
        pids_wrappers[p.pid] = (vpn_id, subnet, p)

    while pids_wrappers:
        wait_for_wrapper_process()


def prepare_logging(hour):
    """Set up the log file for this run.

    Makes log() prefix messages with ``hour`` and creates the log file if it
    is missing. It then hands the file's ownership to the '0tdns' user so
    that user can write to it; the file's group is left unchanged.
    """
    set_loghour(hour)

    log_path = Path(logfile)
    log_path.touch()

    ztdns_uid = getpwnam('0tdns').pw_uid
    chown(logfile, ztdns_uid, -1)  # -1: keep the current group

# Round the current UTC time down to the full hour; this datetime format is
# one of those accepted by postgres.
# NOTE(review): %z is expected to render as +0000 for a gmtime() struct -
# confirm on the target platform.
hour = strftime('%Y-%m-%d %H:00%z', gmtime())
prepare_logging(hour)

if lock_on_file():
    # the lock is released even when the hourly work raises
    try:
        log('Running for {}'.format(hour))
        do_hourly_work(hour)
    finally:
        if not unlock_on_file():
            log("Can't remove lock - {} already deleted!".format(lockfile))
else:
    log('Failed trying to run for {}; {} exists'.format(hour, lockfile))