#!/usr/bin/python3

# HTTP Host header for SSH through use of ProxyCommand
# Copyright (C) 2016-2025  Kasper Dupont

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation version 3.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os
import socket
import sys
import threading

def backend_connect(addresses):
    """Connect to the first reachable address and return the connected socket.

    *addresses* is a sequence of ``socket.getaddrinfo`` results; element
    ``[:3]`` gives (family, type, proto) and element ``[4]`` the sockaddr.
    Each entry is tried in order.  A failure is reported on stderr and the
    next entry is tried.  Returns None when no entry accepts a connection.
    """
    for address in addresses:
        # Creating the socket can itself fail (e.g. an address family the
        # host does not support); treat that like a failed connect.
        try:
            sock = socket.socket(*address[:3])
        except OSError as e:
            print(type(e), e, file=sys.stderr)
            continue
        try:
            sock.connect(address[4])
        except OSError as e:
            # stdout carries the SSH byte stream, so diagnostics must go to
            # stderr.  Release the failed socket instead of leaking it.
            print(type(e), e, file=sys.stderr)
            sock.close()
        else:
            return sock

    return None

def encode_int(v):
    """Encode *v* as a 4-byte big-endian unsigned integer (SSH ``uint32``)."""
    return int.to_bytes(v, length=4, byteorder='big')

def encode_string(s):
    """Encode byte string *s* as an SSH ``string``: uint32 length, then the bytes."""
    length_prefix = encode_int(len(s))
    return b''.join([length_prefix, s])

def get_server_banner(sock, hostname):
    """Do a throwaway exchange on *sock* and return the server's first line.

    Sends the identification line ``SSH-2.0-SNI-ProxyCommand / HTTP/1.0``
    followed by one binary packet whose padding contains
    ``Host: <hostname>``.  It then reads until the first ``b'\\n'``.  Only
    that first line (including the newline) is returned; anything received
    after it is discarded.

    Raises ConnectionError if the peer closes the connection before
    sending a complete line.
    """
    # NOTE(review): the bytes after the type byte (uint32 11, a description
    # string, an empty language-tag string) match the layout of
    # SSH_MSG_DISCONNECT (type 1, reason 11 = BY_APPLICATION in RFC 4253),
    # but the type byte here is 2 (SSH_MSG_IGNORE).  Confirm which was
    # intended; left unchanged.
    payload = b'\x02\x00\x00\x00\x0B' + encode_string(
        b'SNI ProxyCommand preliminary connection'
        ) + encode_string(b'')

    # The Host header rides in the packet's padding field.  The leading ': '
    # presumably makes the preceding binary bytes read as a header line to
    # an HTTP-aware front end.
    padding = (': \r\nHost: %s\r\n\r\n' % hostname).encode('iso-8859-1')

    # Pad with NUL bytes until
    # 4 (length) + 1 (padding length) + payload + padding
    # is a multiple of 8.
    while (1 + len(payload) + len(padding)) % 8 != 4:
        padding += bytes(1)

    # sendall: a plain send() may transmit only part of the buffer.
    sock.sendall(b'SSH-2.0-SNI-ProxyCommand / HTTP/1.0\r\n' +
                 encode_string(bytes([len(padding)]) + payload + padding))
    data = b''
    while b'\n' not in data:
        chunk = sock.recv(4096)
        if not chunk:
            # recv() returns b'' at EOF; without this check the loop never ends.
            raise ConnectionError(
                'connection closed before the server sent its banner')
        data += chunk
    return data.split(b'\n', 1)[0] + b'\n'

def copy_data(src):
    """Relay everything received on socket *src* to stdout (fd 1) until EOF.

    When relaying stops, whether at EOF or because of an error, fd 1 is
    replaced by a duplicate of stderr (fd 2).  This closes the original
    stdout while keeping fd 1 a valid descriptor.  Presumably the purpose
    is that the ssh process reading our stdout sees end-of-file.
    """
    try:
        while True:
            data = src.recv(4096)
            if not data:
                return
            # os.write may write fewer bytes than requested; keep writing
            # the remainder until the whole chunk is out.
            view = memoryview(data)
            while view:
                view = view[os.write(1, view):]
    finally:
        os.dup2(2, 1)

def packet_length(packet):
    """Return the total byte size of the packet at the start of *packet*.

    The first 4 bytes are a big-endian length that excludes the length
    field itself, so 4 is added to it.  While fewer than 4 bytes are
    available, 4 is returned (the size of the length prefix alone).
    """
    header = packet[:4]
    if len(header) == 4:
        return 4 + int.from_bytes(header, byteorder='big')
    return 4

def bytes_needed(data):
    """Return how many more bytes are needed to complete the first packet.

    *data* is a banner line terminated by ``b'\\n'`` followed by the start
    of a length-prefixed packet.  The result is 0 once the packet is
    complete.  Raises ValueError if *data* contains no newline.
    """
    _banner, packet = data.split(b'\n', 1)
    missing = packet_length(packet) - len(packet)
    return missing if missing > 0 else 0

def mangle(data, hostname):
    """Rewrite the first packet in *data* so its padding carries a Host header.

    *data* is a banner line ending in ``b'\\n'`` followed by at least one
    complete length-prefixed packet:

        uint32 length, byte padding-length, payload, padding

    The payload is kept unchanged.  The padding is replaced by
    ``\\r\\nHost: <hostname>\\r\\n\\r\\n`` plus NUL bytes, so that its length
    is congruent mod 8 to the old padding length.  The packet's total size
    therefore changes only by a multiple of 8.  Bytes after the first packet
    are passed through untouched.

    Raises ValueError if the padding length does not fit inside the packet,
    or if the new padding is 256 bytes or longer (too long a hostname).
    """
    banner, packet = data.split(b'\n', 1)
    l = packet_length(packet)
    trail = packet[l:]
    packet = packet[:l]

    padl = packet[4]
    if 5 + padl > len(packet):
        raise ValueError('padding length %d exceeds packet size %d'
                         % (padl, len(packet)))
    # Explicit end index: packet[5:-padl] would be empty when padl == 0.
    packet_data = packet[5:len(packet) - padl]

    new_padding = ('\r\nHost: %s\r\n\r\n' % hostname).encode('iso-8859-1')
    while (len(new_padding) % 8) != (padl % 8):
        new_padding += bytes(1)

    padl = len(new_padding)
    # The padding length is stored in a single byte.  Raise explicitly,
    # since an assert would be stripped under -O.
    if padl >= 256:
        raise ValueError('hostname too long: %d bytes of padding do not fit '
                         'in a one-byte length' % padl)

    packet_payload = bytes([padl]) + packet_data + new_padding

    return banner + b'\n' + encode_string(packet_payload) + trail

def _read_stdin(size):
    """Read at most *size* bytes from stdin (fd 0), exiting if stdin is at EOF."""
    chunk = os.read(0, size)
    if not chunk:
        sys.exit('stdin closed before the first SSH packet was complete')
    return chunk

def _recv_exact(sock, size):
    """Receive *size* bytes from *sock*; return fewer only if the peer closes."""
    data = b''
    while len(data) < size:
        chunk = sock.recv(size - len(data))
        if not chunk:
            break
        data += chunk
    return data

def main(argv):
    """Relay ssh (stdin/stdout) to HOSTNAME:PORT with an injected Host header.

    *argv* is ``[program, hostname, port]``.  Steps:

    1. A preliminary connection fetches the server's first line, which is
       written to stdout.
    2. ssh's identification line and first binary packet are read from
       stdin and rewritten by ``mangle``.
    3. A second connection receives the rewritten data.  Its own server
       banner is consumed and must match the one already written.
    4. A background thread copies server data to stdout, while stdin is
       copied to the server until EOF, after which the write side is shut
       down.

    Exits with a message on stderr on a wrong argument count, if no address
    accepts a connection, if stdin closes early, or if the banners differ.
    """
    if len(argv) != 3:
        sys.exit('usage: %s HOSTNAME PORT'
                 % (argv[0] if argv else 'ProxyCommand'))
    unused, hostname, port = argv

    addresses = socket.getaddrinfo(hostname, port, 0, socket.SOCK_STREAM)
    probe = backend_connect(addresses)
    if probe is None:
        sys.exit('could not connect to %s port %s' % (hostname, port))
    server_banner = get_server_banner(probe, hostname)
    os.write(1, server_banner)

    # Read ssh's identification line, then keep reading until the first
    # binary packet after it is complete.
    data = _read_stdin(64)
    while b'\n' not in data:
        data += _read_stdin(64)
    while True:
        n = bytes_needed(data)
        if not n:
            break
        data += _read_stdin(n)

    sock = backend_connect(addresses)
    # The preliminary connection is closed only once the real one is being
    # set up, matching the original lifetime, but now explicitly.
    probe.close()
    if sock is None:
        sys.exit('could not connect to %s port %s' % (hostname, port))
    sock.sendall(mangle(data, hostname))
    # ssh already received a banner in step 1, so the one on this connection
    # is consumed here.  It may arrive in several segments, so loop until
    # all of it is read.
    server_banner2 = _recv_exact(sock, len(server_banner))
    if server_banner2 != server_banner:
        sys.exit('server banner changed between connections: %r != %r'
                 % (server_banner, server_banner2))

    threading.Thread(target = copy_data,
                     args = [sock]).start()

    while True:
        data = os.read(0, 4096)
        if not data:
            sock.shutdown(socket.SHUT_WR)
            return
        # sendall: a plain send() may transmit only part of the buffer.
        sock.sendall(data)

if __name__ == '__main__':
    # Script entry point; argv is expected to be [program, hostname, port].
    # main() returns None, so the process exits with status 0 once the relay
    # thread has also finished.
    sys.exit(main(sys.argv))
