#!/usr/bin/python3

# A collection of software to aid the usage of IPv6
# Copyright (C) 2018-2019  Kasper Dupont

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation version 3.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import codecs
import datetime
from hashlib import sha224
import hmac
from os import path, urandom
import random
import socket
import sys

from decode_port import decode_port

# Message types
#
# These values are the keys of ControlPacket.messages.  Embedded IP packets
# are stored under the two string keys; every other message is keyed by the
# low nibble of its type byte, which ControlPacket.parse accepts in the
# range 0x30-0x3f.
IPv4 = 'IPv4'
IPv6 = 'IPv6'
PROTOCOL_IDENTIFIER = 1
NO_OP_MESSAGE = 2
TOKEN = 4
TOKEN_ROTATION = 5
REMOTE_AUTHENTICATOR = 6
CONNECTION_TRACKING = 7
PORT_HINT = 8

# Message whose payload is passed to decode_port() by Server.port_switch.
SWITCH_TO_IANA_PORT = 15


def hex(b):
    """Print the bytes-like object *b* to stdout as lowercase hex digits.

    Debugging helper; returns None.  Note that this name shadows the
    builtin ``hex`` inside this module.
    """
    hex_digits = codecs.encode(b, 'hex')
    print(hex_digits.decode('ascii'))


def bytes_to_int(b):
    """Return the bytes-like object *b* as an unsigned big-endian integer.

    An empty input yields 0, so callers slicing a field out of a truncated
    packet get a harmless value instead of an exception.
    """
    return int.from_bytes(b, 'big')


class Logger(object):
    """Append timestamped lines to one log file per calendar day.

    Files are created inside *dirname* and named after the local date
    (``YYYY-MM-DD``).  The file is opened lazily on the first call to
    ``log`` and replaced whenever the date changes.
    """

    def __init__(self, dirname):
        self.dirname = dirname
        # Name of the currently open log file, or None before the first log.
        self.current_logname = None
        self.logfile = None

    def log(self, message):
        """Write ``'<timestamp> <message>'`` as one line to today's log file.

        Raises OSError if the log file for a new day cannot be opened; in
        that case the logger state is left untouched so a later call
        retries the open.
        """
        timestamp = datetime.datetime.now()
        logname = str(timestamp.date())
        if logname != self.current_logname:
            # Open the new file before touching any state: if this raises,
            # the previous file (if any) stays usable and the next call
            # tries again.  Buffering 1 selects line buffering.
            new_logfile = open(path.join(self.dirname, logname), 'a', 1)
            if self.logfile is not None:
                # Do not leak the previous day's file descriptor.
                self.logfile.close()
            self.logfile = new_logfile
            self.current_logname = logname
        self.logfile.write('{timestamp} {message}\n'.format(
            timestamp=timestamp, message=message))


class ConnectionTable(object):
    """Mapping between IPv4 ports and IPv6 clients for one peer IPv4 address.

    Entries are ``(connection, authenticator)`` tuples.  *connection* is a
    34 byte connection tracking message:

        byte 0       type byte (CONNECTION_TRACKING + 0x30)
        byte 1       length byte (16)
        bytes 2-3    mapped IPv4 port, big endian
        bytes 4-33   client identifier (as produced by
                     IPv6Packet.get_client_identifier)

    Entries are indexed both by mapped port (``table_by_port``) and by
    client identifier (``table_by_client``).

    Allocatable ports are kept in a circular doubly linked list stored in
    the arrays ``lru_next`` / ``lru_prev`` with index 0 as the sentinel:
    ``lru_next[0]`` is the most recently used port and ``lru_prev[0]`` the
    least recently used one.  Ports that are not allocatable have ``None``
    in both arrays.
    """

    def __init__(self, address, logger):
        """Create an empty table for the peer at *address* ``(ip, port)``."""
        NUMBER_OF_PORTS = 1 << 16
        self.address = address[0]
        self.table_by_client = {}
        self.table_by_port = [None] * NUMBER_OF_PORTS
        self.lru_prev = [None] * NUMBER_OF_PORTS
        self.lru_next = [None] * NUMBER_OF_PORTS
        self.logger = logger

        # Ports below 1024, a few fixed ports and the peer's own source
        # port are never placed in the allocation list.
        avoid = {3544, 62862, 62863, address[1]}
        port_numbers = [p for p in range(1024, NUMBER_OF_PORTS)
                        if p not in avoid]
        # Randomize the order in which ports are handed out.
        random.shuffle(port_numbers)

        # Link sentinel -> ports... -> sentinel into a circular list.
        for p1, p2 in zip([0] + port_numbers,
                          port_numbers + [0]):
            self.lru_next[p1] = p2
            self.lru_prev[p2] = p1

        logger.log('Initialized {ipv4}'.format(ipv4=address[0]))

    def log(self, action, connection):
        """Log *action* together with both endpoints of *connection*."""
        self.logger.log('{action} {ipv4}:{mapped_port} [{ipv6}]:{port}'.format(
            action=action,
            ipv4=self.address,
            mapped_port=bytes_to_int(connection[2:4]),
            ipv6=socket.inet_ntop(socket.AF_INET6, connection[16:32]),
            port=bytes_to_int(connection[32:]),
        ))

    @staticmethod
    def _is_well_formed(connection):
        """Return True if *connection* is a complete tracking message.

        The length test comes first so that short input is rejected before
        any byte is inspected.
        """
        return (len(connection) == 34
                and connection[0] == CONNECTION_TRACKING + 0x30
                and connection[1] == 16)

    def insert(self, connection, authenticator):
        """Insert a mapping unless its port or client is already mapped.

        Returns the entry now responsible for the port/client: an existing
        entry if the mapped port or the client identifier is already in the
        table, otherwise the newly created ``(connection, authenticator)``.
        Returns None when *connection* is malformed or names port 0.
        """
        # *connection* may come straight from the network.  Truncated or
        # mis-sized messages must be rejected here instead of failing in an
        # assert or in the address formatting done by log().
        if not self._is_well_formed(connection):
            return None

        port_number = bytes_to_int(connection[2:4])
        if not port_number:
            return None

        # The header was validated above and the port is the lookup key, so
        # an existing entry agrees with *connection* on its first 4 bytes.
        old_entry = self.table_by_port[port_number]
        if old_entry:
            return old_entry

        client = connection[4:]
        try:
            return self.table_by_client[client]
        except KeyError:
            pass

        entry = (connection, authenticator)
        self.table_by_client[client] = entry
        self.table_by_port[port_number] = entry

        self.move_to_front_of_lru(port_number)
        self.log('Created', connection)

        return entry

    def move_to_front_of_lru(self, index):
        """Mark port *index* as most recently used.

        Ports that are not part of the allocation list are ignored, and so
        is index 0: it is the list sentinel and relinking it would leave
        the sentinel pointing at itself.
        """
        successor = self.lru_next[index]
        if index == 0 or successor is None:
            return
        predecessor = self.lru_prev[index]

        # Unlink from the current position.
        self.lru_next[predecessor] = successor
        self.lru_prev[successor] = predecessor

        # Relink directly after the sentinel.
        old_front = self.lru_next[0]
        self.lru_prev[old_front] = index
        self.lru_next[index] = old_front
        self.lru_prev[index] = 0
        self.lru_next[0] = index

    def discard_lru_entry(self):
        """Free the least recently used port and return its number.

        Any entry stored for that port is removed from both indexes and
        logged.  The port itself stays at the tail of the list until a
        later insert() moves it to the front.
        """
        new_port = self.lru_prev[0]
        entry = self.table_by_port[new_port]

        if entry is None:
            return new_port

        self.table_by_port[new_port] = None

        connection = entry[0]
        assert bytes_to_int(connection[2:4]) == new_port

        popped_entry = self.table_by_client.pop(connection[4:])
        assert popped_entry is entry

        self.log('Deleted', connection)

        return new_port

    def process_from_nat64(self, p):
        """Insert the connection entries announced in control packet *p*.

        Nothing is done unless *p* carries a connection tracking message, a
        remote authenticator and a protocol identifier.  Returns either a
        list of byte strings (one per authenticator, each followed by the
        connections stored under it), a single entry tuple, or None; every
        non-None result is an iterable of byte strings.
        """
        try:
            connection = p.messages[CONNECTION_TRACKING]
            authenticator = p.messages[REMOTE_AUTHENTICATOR]
            p.messages[PROTOCOL_IDENTIFIER]
        except KeyError:
            return None

        responses = {}
        for entry in p.connection_entries:
            response = self.insert(*entry)
            if response:
                responses.setdefault(response[1], []).append(response[0])
        if responses:
            return [b''.join([response_authenticator] + connections)
                    for response_authenticator, connections
                    in responses.items()]

        # No entry was paired with an authenticator while parsing; fall
        # back to the last connection and authenticator seen in the packet.
        if len(connection) != 34:
            return None

        return self.insert(connection, authenticator)

    def lookup_v4(self, mapped_port):
        """Return the entry for *mapped_port* (refreshing it), or None."""
        entry = self.table_by_port[mapped_port]
        if entry is not None:
            self.move_to_front_of_lru(bytes_to_int(entry[0][2:4]))
        return entry

    def lookup_v6_or_create(self, client, messages):
        """Return the entry for *client*, creating one if necessary.

        Returns None when a new entry is needed but *messages* carries no
        remote authenticator.
        """
        try:
            entry = self.table_by_client[client]
        except KeyError:
            return self.create_entry(client, messages)

        self.move_to_front_of_lru(bytes_to_int(entry[0][2:4]))
        return entry

    def port_hint(self, messages):
        """Return the port requested by a PORT_HINT message if it is free.

        Returns None when there is no 4 byte hint or the port is taken.

        NOTE(review): the hint is not checked against the allocation list,
        so a hint may name a port that __init__ excluded; such an entry is
        never reached by discard_lru_entry.  Confirm this is intended.
        """
        try:
            port_message = messages[PORT_HINT]
        except KeyError:
            return None

        if len(port_message) != 4:
            return None

        port_number = bytes_to_int(port_message[2:4])
        if self.table_by_port[port_number] is None:
            return port_number
        return None

    def create_entry(self, client, messages):
        """Allocate a port for *client* and return the new entry.

        The hinted port is used when available (a hint of 0 counts as no
        hint), otherwise the least recently used port is recycled.
        Returns None when *messages* has no remote authenticator.
        """
        try:
            authenticator = messages[REMOTE_AUTHENTICATOR]
        except KeyError:
            return None
        port = self.port_hint(messages) or self.discard_lru_entry()
        connection = bytes([CONNECTION_TRACKING + 0x30, 16,
                            port >> 8, port & 0xff]) + client
        entry = self.insert(connection, authenticator)
        assert entry[0] is connection
        return entry


class Packet:
    """Base class for packet wrappers; holds the raw bytes in ``data``."""

    def __init__(self, data):
        # Raw packet contents as received or extracted by the caller.
        self.data = data


class IPv4Packet(Packet):
    """An IPv4 packet embedded in a control packet."""

    def get_mapped_port(self):
        """Return the 16-bit field at offset 2 of the transport header.

        That field is the destination port for TCP and UDP.  Returns None
        when fewer than 8 bytes follow the IPv4 header or when the protocol
        field (byte 9) is neither 6 nor 0x11.
        """
        packet = self.data
        # Header length is the low nibble of byte 0, in 32-bit words.
        ihl_bytes = 4 * (packet[0] & 0xf)
        if len(packet) < ihl_bytes + 8 or packet[9] not in {6, 0x11}:
            return None
        port_offset = ihl_bytes + 2
        return bytes_to_int(packet[port_offset:port_offset + 2])


class IPv6Packet(Packet):
    """An IPv6 packet embedded in a control packet.

    All methods look at the fixed 40 byte IPv6 header only; packets whose
    next header field (byte 6) names an extension header are not matched.
    """

    def is_echo_request(self):
        """Return True for an ICMPv6 (next header 58) type 128 message."""
        if len(self.data) < 48:
            return False
        if self.data[6] != 58:
            return False
        if self.data[40] != 128:
            return False
        return True

    def produce_echo_response(self):
        """Return the echo reply answering this echo request.

        Call only after is_echo_request() returned True.  The reply reuses
        the request with the hop limit set to 255, source and destination
        addresses swapped and the ICMPv6 type changed from 128 to 129.
        """
        data = self.data

        # Only the 16-bit word holding type and code changes, and it grows
        # by 0x0100 (the swapped addresses contribute the same sum).  Update
        # the checksum incrementally in ones' complement arithmetic,
        # HC' = ~(~HC + 0x0100), folding the end-around carry.  Simply
        # decrementing the high byte is wrong when that byte is 0.
        old_checksum = (data[42] << 8) | data[43]
        total = (old_checksum ^ 0xffff) + 0x0100
        total = (total & 0xffff) + (total >> 16)
        new_checksum = total ^ 0xffff

        return b''.join([
            data[:7], b'\xff',
            data[24:40],
            data[8:24],
            bytes([0x81, data[41],
                   new_checksum >> 8, new_checksum & 0xff]),
            data[44:],
        ])

    def get_client_identifier(self):
        """Return the 30 byte client identifier of a TCP or UDP packet.

        The identifier is the first 12 bytes of the destination address,
        the full source address and the first two bytes after the IPv6
        header (the source port).  Returns None for packets shorter than 48
        bytes or whose next header is neither 6 nor 0x11.
        """
        if len(self.data) < 48:
            return None
        if self.data[6] not in {6, 0x11}:
            return None
        return self.data[24:36] + self.data[8:24] + self.data[40:42]


class ControlPacket(Packet):
    """A received UDP datagram split into its control messages.

    Attributes:
        data: the raw datagram.
        sender: the address the datagram was received from.
        messages: dict mapping message type to the raw message, including
            its header.  When a type occurs more than once the last
            occurrence wins.  Unparseable trailing data is stored under
            the key None.
        connection_entries: list of ``(connection, authenticator)`` pairs,
            one for each complete 34 byte connection tracking message that
            was preceded by a remote authenticator message.
    """

    def __init__(self, data, sender):
        super().__init__(data)
        self.sender = sender
        self.messages = {}
        self.connection_entries = []
        self.parse(data)

    def parse(self, data):
        """Split *data* into messages and record them on this object.

        Three framings are recognized by the first byte:

        * 0x30-0x3f: control message; the low nibble is the type and the
          second byte L gives a total length of 2 * (L + 1) bytes.
        * 0x45-0x4f: IPv4 packet; the length is its total length field.
        * 0x60-0x6f: IPv6 packet; the length is 40 plus its payload length
          field.

        Anything else, or an IP header that does not fit, makes the rest of
        the datagram one message of type None.  A single trailing byte is
        ignored.
        """
        while len(data) >= 2:
            type_byte = data[0]
            message_type = None

            if 0x30 <= type_byte <= 0x3f:
                message_type = type_byte & 0xf
                message_len = 2 * (data[1] + 1)

            elif 0x45 <= type_byte <= 0x4f:
                header_len = 4 * (type_byte & 0xf)
                if len(data) >= header_len:
                    message_len = bytes_to_int(data[2:4])
                    if message_len >= header_len:
                        message_type = IPv4

            elif 0x60 <= type_byte <= 0x6f:
                if len(data) >= 40:
                    message_type = IPv6
                    message_len = 40 + bytes_to_int(data[4:6])

            if message_type is None:
                message_len = len(data)

            # The advertised length may exceed what is left of the
            # datagram, in which case the slice is shorter than message_len.
            message = data[:message_len]
            self.messages[message_type] = message

            # Record a connection entry only for a complete message:
            # consumers slice fixed offsets out of it, so a truncated one
            # must not be passed on.
            if (message_type == CONNECTION_TRACKING
                    and message_len == 34 and len(message) == 34):
                try:
                    self.connection_entries.append(
                        (message, self.messages[REMOTE_AUTHENTICATOR]))
                except KeyError:
                    # No authenticator seen yet; nothing to pair with.
                    pass

            data = data[message_len:]


class Attenuation:
    """Three-state machine deciding which packets go unanswered.

    ``attenuate()`` returns a truthy value when the response to the
    current packet should be dropped.  The state is one of True (the last
    packet was dropped), None (answered, next decision is random) and
    False (answered, next packet is dropped).

    This guarantees:
    - 40% attenuation
    - Never drop two packets in a row
    - Never respond to more than two packets in a row
    - No matter how an adversary times spoofed packets every legitimate
      packet has at least 50% chance of receiving a response.

    For established peers there is a state machine per peer.
    For non-established peers there is one global state machine.
    """

    def __init__(self):
        self.state = True

    def attenuate(self):
        """Advance the state machine and return the new state."""
        previous = self.state
        if previous is None:
            # Fair coin flip from the operating system's random source.
            self.state = urandom(1)[0] < 128
        elif previous:
            self.state = None
        else:
            self.state = True
        return self.state


class Peer(object):
    """State for one remote endpoint, identified by its ``(ip, port)``.

    A peer starts out unlinked: it is not stored in the server and has no
    token, attenuation state or connection table.  It becomes linked once
    it has returned a valid cookie (see authenticate).
    """

    def __init__(self, address):
        self.address = address
        # False until a token has been accepted for this peer.
        self.token = False
        self.attenuation = None
        self.connection_table = None

    @staticmethod
    def _same_bytes(left, right):
        """Compare two secrets in constant time.

        Anything that is not a bytes object (the initial ``False`` token,
        a missing message) never matches.
        """
        return (isinstance(left, bytes) and isinstance(right, bytes)
                and hmac.compare_digest(left, right))

    def link(self, server):
        """Register this peer with *server* and attach its per-peer state."""
        server.peers[self.address] = self
        if self.attenuation is None:
            self.attenuation = Attenuation()
        self.connection_table = server.get_connection_table(self.address)

    def send_cookie(self, server, message):
        """Send *message* to port 9 of the peer unless it is attenuated.

        Uses the peer's own attenuation state when it has one and the
        server-wide state otherwise.
        """
        attenuation = self.attenuation or server.attenuation
        if attenuation.attenuate():
            return
        server.s.sendto(message, (self.address[0], 9))

    def authenticate(self, server, messages):
        """Return True if *messages* authenticates this peer.

        A packet is accepted when its token equals the stored one, when it
        rotates the stored token, or when its embedded IPv4 packet ends
        with the cookie the server derives from the token and the peer
        address; the last case also links the peer.  Otherwise a cookie is
        sent (subject to attenuation) and False is returned.
        """
        try:
            token = messages[TOKEN]
        except KeyError:
            return False

        if self._same_bytes(token, self.token):
            return True

        # NOTE(review): messages produced by ControlPacket.parse keep their
        # type byte, so a TOKEN_ROTATION message (first byte 0x35) can
        # apparently never equal a stored TOKEN message (first byte 0x34).
        # Confirm whether this branch is reachable as written.
        if self._same_bytes(messages.get(TOKEN_ROTATION), self.token):
            self.token = token
            return True

        h = hmac.new(server.secret, digestmod=sha224)
        h.update(token + repr(self.address).encode('ascii'))
        cookie = h.digest()

        # Constant-time equivalent of ``endswith(cookie)``: a message
        # shorter than the cookie yields a shorter slice and cannot match.
        ipv4_message = messages.get(IPv4, b'')
        if hmac.compare_digest(ipv4_message[-len(cookie):], cookie):
            self.token = token
            self.link(server)
            return True

        try:
            cookie_headers = self.create_cookie_headers(messages)
        except KeyError:
            return False

        self.send_cookie(server, cookie_headers + cookie)
        return False

    def move_to_front_of_lru(self):
        """Mark this peer's own source port as most recently used.

        Only valid for a linked peer (connection_table is set by link).
        """
        self.connection_table.move_to_front_of_lru(self.address[1])

    def create_cookie_headers(self, messages):
        """Return the bytes that precede the cookie in a cookie packet.

        Prefers the form built around the embedded IPv6 packet and falls
        back to the connection tracking form.  Raises KeyError when
        *messages* lacks the messages needed for both.
        """
        try:
            return self.create_proto_59_cookie_headers(messages)
        except KeyError:
            return self.create_tracking_cookie_headers(messages)

    def create_tracking_cookie_headers(self, messages):
        """Build headers from the authenticator and tracking messages.

        Raises KeyError unless *messages* has a protocol identifier, a
        remote authenticator and a connection tracking message.
        """
        # Evaluated only for the KeyError it raises when missing.
        messages[PROTOCOL_IDENTIFIER]

        return b''.join([
            messages[REMOTE_AUTHENTICATOR],
            messages[CONNECTION_TRACKING],
            # NOTE(review): 224 >> 9 evaluates to 0, so this no-op message
            # advertises a total length of 2 bytes and does not cover the
            # 28 byte cookie appended by authenticate - confirm intended.
            bytes([NO_OP_MESSAGE + 0x30, 224 >> 9]),
        ])

    def create_proto_59_cookie_headers(self, messages):
        """Build headers as an IPv6 header with next header 59.

        The header has hop limit 1 and the addresses of the embedded IPv6
        packet swapped.  Raises KeyError unless *messages* has a remote
        authenticator and an IPv6 packet.
        """
        return b''.join([
            messages[REMOTE_AUTHENTICATOR],
            # NOTE(review): 224 >> 8 evaluates to 0, i.e. the payload
            # length field is 0 although the cookie follows this header -
            # confirm intended.
            bytes([0x60, 0, 0, 0,
                   0, 224 >> 8, 59, 1]),
            messages[IPv6][24:40],
            messages[IPv6][8:24],
        ])


class Server(object):
    """UDP server answering control packets from peers.

    Attributes:
        s: the UDP socket used for receiving and sending.
        attenuation: Attenuation state shared by all unlinked peers.
        peers: dict mapping ``(ip, port)`` to linked Peer objects.
        connection_tables: dict mapping peer IP to its ConnectionTable.
        secret: random key used to derive cookies.
        logger: Logger writing into the log directory.
    """

    def __init__(self, address, logdir):
        """Bind a UDP socket to *address* and log below *logdir*."""
        self.s = self._bound_udp_socket(address)
        self.attenuation = Attenuation()
        self.peers = {}
        self.connection_tables = {}
        self.secret = urandom(32)
        self.logger = Logger(logdir)
        self.logger.log('Listening on {address}'.format(address=address))

    @staticmethod
    def _bound_udp_socket(address):
        """Return a new UDP socket bound to *address*; close it on failure."""
        sock = socket.socket(
            socket.AF_INET,
            socket.SOCK_DGRAM,
            socket.IPPROTO_UDP,
        )
        try:
            sock.bind(address)
        except Exception:
            sock.close()
            raise
        return sock

    def port_switch(self, p):
        """Move the server to the port named by a switch message in *p*.

        Does nothing when *p* has no SWITCH_TO_IANA_PORT message.  The
        switch happens at most once: afterwards this method is replaced by
        a no-op on the instance.

        NOTE(review): process_packet calls this before authenticating the
        sender; whether decode_port validates the request cannot be seen
        from this file.
        """
        try:
            new_port = decode_port(p.messages[SWITCH_TO_IANA_PORT])
        except KeyError:
            return

        self.port_switch = lambda *x: None
        # Release the old port explicitly rather than relying on garbage
        # collection of the replaced socket object.
        self.s.close()
        self.s = self._bound_udp_socket(('0.0.0.0', new_port))
        self.logger.log('Switched to port {port}'.format(port=new_port))

    def recv(self):
        """Block until a datagram arrives and return it as ControlPacket."""
        return ControlPacket(*self.s.recvfrom(2**16))

    def return_packet_with_connection_entry(self, entry, packet, p):
        """Send *packet* back to the sender of *p*, prefixed by *entry*.

        Without an entry only IPv4 packets are returned, prefixed by the
        sender's token.
        """
        if not entry:
            if isinstance(packet, IPv4Packet):
                self.s.sendto(p.messages[TOKEN] + packet.data, p.sender)
            return

        message = b''.join(entry)
        self.s.sendto(message + packet.data, p.sender)

        # Send an extra copy of just the connection tracking data
        # in case the combined packet experienced an MTU issue.
        self.s.sendto(message, p.sender)

    def process_ipv4(self, p, packet, peer):
        """Answer an embedded IPv4 packet with the mapping of its port."""
        mapped_port = packet.get_mapped_port()
        if mapped_port:
            entry = peer.connection_table.lookup_v4(mapped_port)
            self.return_packet_with_connection_entry(entry, packet, p)

    def process_ipv6(self, p, packet, peer):
        """Answer an embedded IPv6 packet.

        TCP and UDP packets get the (possibly new) mapping of their
        client; ICMPv6 echo requests get an echo reply.
        """
        client_identifier = packet.get_client_identifier()
        if client_identifier:
            entry = peer.connection_table.lookup_v6_or_create(
                client_identifier, p.messages)
            self.return_packet_with_connection_entry(entry, packet, p)
        elif packet.is_echo_request():
            self.s.sendto(b''.join([
                p.messages.get(REMOTE_AUTHENTICATOR, b''),
                packet.produce_echo_response(),
            ]), p.sender)

    def get_peer(self, address):
        """Return the linked peer for *address* or a fresh unlinked one."""
        try:
            return self.peers[address]
        except KeyError:
            return Peer(address)

    def get_connection_table(self, address):
        """Return the table for the IP of *address*, creating it if needed.

        Tables are shared by all peers with the same IP address; the port
        of the first peer seen is passed on to the new ConnectionTable.
        """
        ip = address[0]
        table = self.connection_tables.get(ip)
        if table is None:
            # Construct a table only when one is really missing.
            table = ConnectionTable(address, self.logger)
            self.connection_tables[ip] = table
        return table

    def run(self):
        """Receive and process packets forever."""
        while True:
            self.process_packet(self.recv())

    def process_packet(self, p):
        """Handle one control packet: switch port, authenticate, answer."""
        self.port_switch(p)

        peer = self.get_peer(p.sender)
        if not peer.authenticate(self, p.messages):
            return
        peer.move_to_front_of_lru()

        if IPv4 in p.messages:
            self.process_ipv4(p, IPv4Packet(p.messages[IPv4]), peer)
        if IPv6 in p.messages:
            self.process_ipv6(p, IPv6Packet(p.messages[IPv6]), peer)

        response = peer.connection_table.process_from_nat64(p)
        if response:
            self.s.sendto(b''.join(response), p.sender)


def main(argv):
    """Run a server on UDP port 62863 of all IPv4 interfaces.

    Log files go to the directory ``logs`` relative to the working
    directory.  *argv* is accepted but not used.  Does not return.
    """
    server = Server(('0.0.0.0', 62863), 'logs')
    server.run()


# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main(sys.argv)
