#!/usr/bin/python3

# Copyright (C) 2024 Kasper Dupont

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation version 3.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import fcntl
import os
import select
import socket
import struct
import subprocess
import sys


class Tun(object):
    """A Linux TUN device opened via /dev/net/tun.

    The device is created with IFF_TUN only (IFF_NO_PI is not set), so every
    packet read from or written to the descriptor is preceded by a 4-byte
    packet-information header (2 bytes flags, 2 bytes protocol).
    """

    def __init__(self):
        IFF_TUN = 1
        TUNSETIFF = 0x400454ca

        self.fd = os.open("/dev/net/tun", os.O_RDWR)
        try:
            # struct ifreq starts with a 16-byte interface name followed by
            # the flags.  "same%d" is a name template; the name the kernel
            # actually assigned is read back from the returned buffer below.
            ifs = fcntl.ioctl(self.fd, TUNSETIFF,
                              struct.pack("16sH", b"same%d", IFF_TUN))
        except OSError:
            # Do not leak the descriptor when the device cannot be created
            # (e.g. missing privileges).
            os.close(self.fd)
            raise

        self.dev_name = ifs[:16].decode("iso-8859-1").strip("\x00")

    def fileno(self):
        """Return the device descriptor so the object works with select()."""
        return self.fd

    def close(self):
        """Close the underlying device file descriptor."""
        os.close(self.fd)

    def ip(self, *args):
        """Run /sbin/ip with ``args``.

        Raises subprocess.CalledProcessError if the command exits non-zero,
        so a misconfigured interface is not silently ignored.
        """
        subprocess.check_call(("/sbin/ip",) + args)

    def set_ip(self):
        """Bring the device up, assign 100.100.0.1/22 and enable IPv4 forwarding.

        Writing /proc/sys/net/ipv4/ip_forward is a system-wide setting.
        """
        self.ip("link", "set", self.dev_name, "up")
        self.ip("address", "add", "100.100.0.1/22", "dev", self.dev_name)
        with open("/proc/sys/net/ipv4/ip_forward", "w") as f:
            f.write("1")

    def recv(self):
        """Read one packet, including its 4-byte packet-information header."""
        return os.read(self.fd, 65535)

    def send(self, data):
        """Write the raw IPv4 packet ``data`` to the device.

        The packet-information header (flags 0, protocol 0x0800 = IPv4) is
        prepended here.
        """
        os.write(self.fd, b"\x00\x00\x08\x00" + data)


class SameHost:
    """Look for remote addresses that share a single IPID counter.

    Echo requests are injected through a TUN device from source addresses in
    100.100.0.0/22, and the IPv4 identification field (IPID) of each reply is
    recorded.  Pairs of remote addresses whose IPIDs advance in lock-step are
    reported as sharing one counter.
    """

    def suffix_to_ip(self, suffix):
        """Return the packed IPv4 address 100.100.(suffix >> 8).(suffix & 0xff)."""
        return bytes([100, 100, suffix >> 8, suffix & 0xff])

    def ping(self, src_ip, dst_ip):
        """Inject one ICMP echo request from ``src_ip`` to ``dst_ip``.

        Both addresses are 4-byte packed IPv4 addresses.  The IPv4 header has
        no options, identification 0x3333 and TTL 255; the ICMP message is an
        echo request with identifier and sequence number 0 and no payload.
        """
        # IPv4 header minus checksum and addresses: version/IHL, TOS, total
        # length (20-byte header + 8-byte ICMP), identification,
        # flags/fragment offset, TTL, protocol.
        header = bytes([0x45, 0, 0, 20 + 8,
                        0x33, 0x33, 0, 0,
                        255, socket.IPPROTO_ICMP])
        ips = src_ip + dst_ip
        # Since 2**16 == 1 (mod 0xffff), reducing the big-endian integer of
        # the header words modulo 0xffff gives their one's-complement sum;
        # 0xffff minus that sum is the header checksum.
        csum = (0xffff - int.from_bytes(header + ips, "big") % 0xffff
                ).to_bytes(2, "big")
        # Echo request (type 8, code 0) with its precomputed checksum 0xf7ff.
        icmp = bytes([8, 0, 0xf7, 0xff, 0, 0, 0, 0])
        self.tun.send(header + csum + ips + icmp)

    def get_responses(self, timeout):
        """Read queued packets from the TUN device into self.responses.

        Waits up to ``timeout`` seconds for the first packet, then drains any
        further packets without waiting.  Each IPv4 packet is appended as
        ``(identification, src + dst)`` (2 and 8 bytes).  Packets that are
        not IPv4 (for example IPv6) or are shorter than a minimal IPv4 header
        are skipped, because their fields cannot be read this way.
        """
        while select.select([self.tun], [], [], timeout)[0]:
            # Drop the 4-byte packet-information header.
            packet = self.tun.recv()[4:]
            timeout = 0
            if len(packet) < 20 or packet[0] >> 4 != 4:
                continue
            self.responses.append((packet[4:6], packet[12:20]))

    def re_run(self):
        """Ping every previous responder again, in IPID order.

        The old responses are sorted (by IPID first).  Each reply's addresses
        are reversed to form a new echo request, and the new replies replace
        self.responses.  Prints the old and new response counts.
        """
        old = sorted(self.responses)
        self.responses = []

        for _, ips in old:
            # ips is (replier, our address): ping back from our address.
            self.ping(ips[4:], ips[:4])
            self.get_responses(0.1)

        print(len(old), len(self.responses))

    def get_candidates(self):
        """Return pairs of adjacent responses whose IPIDs nearly coincide.

        Walks self.responses in arrival order.  It collects
        ``(previous ips, current ips)`` when the two IPIDs are less than 3
        apart and the replies come from different remote addresses.  The
        distance is taken modulo 2**16, so a wrap from 0xffff to 0 counts.
        """
        candidates = set()
        prev_value = None
        prev_ips = None
        for ip_id, ips in self.responses:
            value = int.from_bytes(ip_id, "big")
            if prev_value is not None:
                diff = (value - prev_value) & 0xffff
                # Circular distance in the 16-bit IPID space.
                close = min(diff, 0x10000 - diff) < 3
                if close and prev_ips[:4] != ips[:4]:
                    candidates.add((prev_ips, ips))
            prev_value = value
            prev_ips = ips

        return candidates

    def print_ips(self, ips):
        """Print a response's addresses as '<our address> > <remote address>'."""
        print(socket.inet_ntop(socket.AF_INET, ips[4:]),
              ">",
              socket.inet_ntop(socket.AF_INET, ips[:4]))

    def get_ip_id(self, ips, max_attempts=50, timeout=0.1):
        """Ping the remote side of ``ips`` and return the IPID of its reply.

        ``ips`` is a (remote, local) 8-byte address pair as stored in
        self.responses.  Stale queued packets are discarded first.  Returns
        the 2-byte identification of the first reply with exactly these
        addresses.  Returns None if none arrived after ``max_attempts`` pings
        with ``timeout`` seconds of waiting each.  This bound keeps an
        unresponsive host from stalling the scan; ``max_attempts=None``
        retries forever.
        """
        # Flush packets already queued so they are not taken as answers.
        self.get_responses(0.0001)
        self.responses = []

        attempt = 0
        while max_attempts is None or attempt < max_attempts:
            attempt += 1
            self.ping(ips[4:], ips[:4])
            self.get_responses(timeout)
            for ip_id, response_ips in self.responses:
                if response_ips == ips:
                    return ip_id
        return None

    def delta(self, id1, id2):
        """Return how far the 16-bit IPID moved from ``id1`` to ``id2``.

        Both IDs are 2-byte big-endian values.  A drop of more than 0xf000 is
        treated as the counter wrapping past 0xffff.
        """
        r = int.from_bytes(id2, "big") - int.from_bytes(id1, "big")
        if r < -0xf000:
            r += 0x10000
        return r

    def evaluate(self, ips1, ips2):
        """Check whether two candidate hosts share one IPID counter.

        Queries ips2, then alternately ips1 and ips2, and measures the IPID
        advance from each reply to the next.  Returns silently if a host
        stops answering or an advance falls outside 1..0x1000.  Once either
        hop has advanced by exactly 1 three times, prints a message
        reporting a shared counter.
        """
        print("Evaluating")
        self.print_ips(ips1)
        self.print_ips(ips2)

        prev_id = self.get_ip_id(ips2)
        if prev_id is None:
            return

        # Cumulative (not reset) counts of exact +1 advances on each hop:
        # ips2 -> ips1 and ips1 -> ips2.
        consecutive1 = 0
        consecutive2 = 0

        # NOTE: if both advances stay within 2..0x1000 indefinitely, this
        # loop does not terminate.
        while consecutive1 < 3 and consecutive2 < 3:
            id1 = self.get_ip_id(ips1)
            if id1 is None:
                return
            id2 = self.get_ip_id(ips2)
            if id2 is None:
                return
            delta1 = self.delta(prev_id, id1)
            delta2 = self.delta(id1, id2)
            # A shared counter must move forward, by a small amount, between
            # back-to-back queries.  delta() already handles wrap-around, so
            # the raw IDs must not be compared directly.
            if not 0 < delta1 <= 0x1000:
                return
            if not 0 < delta2 <= 0x1000:
                return
            if delta1 == 1:
                consecutive1 += 1
            if delta2 == 1:
                consecutive2 += 1
            prev_id = id2

        ip1 = socket.inet_ntop(socket.AF_INET, ips1[:4])
        ip2 = socket.inet_ntop(socket.AF_INET, ips2[:4])

        print("!!!!!! Found shared IPID counter "
              f"between {ip1} and {ip2} !!!!!!!")

    def main(self, argv):
        """Scan the IPv4 addresses in argv[1:] for shared IPID counters.

        Creates and configures the TUN device and pings every target from
        each source 100.100.0.4 .. 100.100.3.251.  It then re-pings all
        responders twice and collects candidates, and does the same again.
        Pairs found in both collections are evaluated.
        """
        ips = [socket.inet_pton(socket.AF_INET, ip) for ip in argv[1:]]
        self.tun = Tun()
        self.tun.set_ip()
        self.responses = []

        for suffix in range(4, 1020):
            for ip in ips:
                self.ping(self.suffix_to_ip(suffix), ip)
                self.get_responses(0.1)

        print(len(self.responses))
        self.re_run()
        self.re_run()

        candidates1 = self.get_candidates()
        print(len(candidates1))

        self.re_run()
        self.re_run()
        candidates2 = self.get_candidates()
        print(len(candidates2))

        for candidate in candidates1.intersection(candidates2):
            self.evaluate(*candidate)


if __name__ == "__main__":
    # Script entry point: pass the full command line to the scanner.
    scanner = SameHost()
    scanner.main(sys.argv)
