Created
July 18, 2022 11:29
-
-
Save Ivlyth/d2b652e50ca2624f63b1c39a16cd751c to your computer and use it in GitHub Desktop.
Python snippet for parsing a self-defined SSH port-forward config
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding:utf8 -*- | |
""" | |
Author : Myth | |
Date : 2021/1/21 | |
Email : email4myth at gmail.com | |
""" | |
from __future__ import unicode_literals | |
import ipaddress | |
import sys | |
# Fallback SSH settings. ConfigLine.parse_ssh_server uses the user, password
# and port when a config line leaves them out; the key path is handed to every
# SSHServer it builds.
global_ssh_user = "root"
# NOTE(review): placeholder value -- do not commit a real password here.
global_ssh_password = "xxxxxx"
global_ssh_port = 22
global_ssh_key_path = "/root/.ssh/id_rsa"
def parse_port(port_s):
    """Parse a port number from a string.

    :param port_s: decimal digits, surrounding whitespace allowed.
    :return: the port as an int in the range 1..65535.
    :raises Exception: if the text is not all digits, cannot be converted by
        ``int()``, or falls outside 0 < port <= 65535.
    """
    text = port_s.strip()
    if not text.isdigit():
        raise Exception("port should be integer, found: %s" % text)
    # isdigit() also accepts characters (e.g. superscripts) that int() rejects
    try:
        value = int(text)
    except ValueError:
        raise Exception("invalid integer: %s" % text)
    if not 0 < value <= 65535:
        raise Exception("invalid port: %d, should in 0 < port <= 65535" % value)
    return value
def parse_ip(ip_s):
    """Validate a dotted-quad IPv4 address string.

    :param ip_s: address text, surrounding whitespace allowed.
    :return: the stripped address string (not an ``IPv4Address`` object).
    :raises Exception: if ``ipaddress.IPv4Address`` rejects the text.
    """
    ip_s = ip_s.strip()
    try:
        ipaddress.IPv4Address(ip_s)
    except ValueError:
        # AddressValueError subclasses ValueError. Catching only that (instead
        # of a bare ``except:``) lets KeyboardInterrupt/SystemExit propagate.
        raise Exception("invalid ipv4 addr: %s" % ip_s)
    return ip_s
def parse_ports(ports_s):
    """Expand a port spec into the ports it denotes.

    Accepted forms are a single port (``"8080"``) or an inclusive range
    ``"start-end"`` where end must be strictly greater than start.

    :return: a one-element list for a single port, or a ``range`` for a span.
    :raises Exception: if either end is not a valid port, or the range is empty
        or reversed.
    """
    spec = ports_s.strip()
    if spec.isdigit():
        # a single port
        return [parse_port(spec)]

    first_s, _, last_s = spec.partition("-")
    try:
        first = parse_port(first_s)
    except Exception as e:
        raise Exception("invalid start port: %s" % e)
    try:
        last = parse_port(last_s)
    except Exception as e:
        raise Exception("invalid end port: %s" % e)
    if not last > first:
        raise Exception("invalid port range, end should large than the start: %s" % spec)
    # both ends are inclusive
    return range(first, last + 1)
def parse_ip_part(part_s):
    """Parse one dotted-quad octet from a string.

    :param part_s: decimal digits, surrounding whitespace allowed.
    :return: the octet as an int in the range 0..255.
    :raises Exception: if the text is not all digits, cannot be converted by
        ``int()``, or falls outside 0 <= part <= 255.
    """
    part_s = part_s.strip()
    if not part_s.isdigit():
        raise Exception("ip part should be integer, found: %s" % part_s)
    try:
        part = int(part_s)
    except ValueError:
        # isdigit() also accepts characters (e.g. superscripts) that int()
        # rejects, so this branch is reachable
        raise Exception("invalid integer: %s" % part_s)
    if part < 0 or part > 255:
        raise Exception("invalid ip part: %d, should in 0 <= part <= 255" % part)
    return part
def parse_ip_parts(parts_s):
    """Expand one octet spec of an IP pattern into the octet values it covers.

    Accepted forms are a single octet (``"81"``) or an inclusive range
    ``"start-end"`` where end must be strictly greater than start.

    :return: a one-element list for a single octet, or a ``range`` for a span.
    :raises Exception: if either end is not a valid octet, or the range is
        empty or reversed.
    """
    spec = parts_s.strip()
    if spec.isdigit():
        # a single octet value
        return [parse_ip_part(spec)]

    low_s, _, high_s = spec.partition("-")
    try:
        low = parse_ip_part(low_s)
    except Exception as e:
        raise Exception("invalid start ip part: %s" % e)
    try:
        high = parse_ip_part(high_s)
    except Exception as e:
        raise Exception("invalid end ip part: %s" % e)
    if not high > low:
        raise Exception("invalid ip part range, end should large than the start: %s" % spec)
    # both ends are inclusive
    return range(low, high + 1)
def parse_ips(ips_s):
    """Expand an IPv4 pattern whose octets may be ranges into concrete addresses.

    Each of the four dot-separated fields is parsed by ``parse_ip_parts``, so
    ``"127.0.0.1-3"`` yields ``127.0.0.1``, ``127.0.0.2`` and ``127.0.0.3``.
    The result is the Cartesian product of the fields, with the first octet
    varying slowest.

    :return: list of address strings.
    :raises Exception: if the pattern does not have exactly four fields or any
        field is invalid.
    """
    spec = ips_s.strip()
    octet_specs = spec.split('.')
    if len(octet_specs) != 4:
        raise Exception("invalid ip format: %s" % spec)
    first, second, third, fourth = [parse_ip_parts(s) for s in octet_specs]
    return [
        parse_ip('%s.%s.%s.%s' % (a, b, c, d))
        for a in first
        for b in second
        for c in third
        for d in fourth
    ]
class SSHServer(object):
    """Connection settings for the SSH host a forward is tunnelled through."""

    def __init__(self, user, password, ip, port, key_path):
        self.user = user
        self.password = password
        self.ip = ip
        self.port = port
        self.key_path = key_path

    def __str__(self):
        # These objects end up in printed/debug output (ForwardConfig.__str__
        # embeds them), so never echo the real password. A fixed mask also
        # avoids leaking the password length.
        masked = '******' if self.password else ''
        return 'SSHServer: %s:%s@%s:%s' % (self.user, masked, self.ip, self.port)
class ForwardConfig(object):
    """A single forward: ``local_port`` -> ``target``, tunnelled via ``ssh_server``.

    Instances are identified by their local port. The hash already used only
    ``local_port``; equality now agrees with it, so set/dict membership works
    as the hash implies.
    """

    def __init__(self, local_port, target, ssh_server):
        self.local_port = local_port
        self.target = target
        self.ssh_server = ssh_server

    def __str__(self):
        return 'local port %d -> %s, through %s' % (self.local_port, self.target, self.ssh_server)

    def __eq__(self, other):
        if not isinstance(other, ForwardConfig):
            return NotImplemented
        return self.local_port == other.local_port

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        return self.local_port
class Target(object):
    """A single destination (address plus port) reached through a forward."""

    def __init__(self, ip, port):
        self.ip, self.port = ip, port

    def __str__(self):
        return 'Target: {}:{}'.format(self.ip, self.port)
class Targets(object):
    """All targets described by an ``ips:ports`` spec.

    The target list is the Cartesian product of the expanded IPs and ports,
    with the IP varying slowest.
    """

    def __init__(self, ips_s, ports_s):
        # raw IP pattern, e.g. "127.0.0.1-33"
        self.ips_s = ips_s
        # raw port spec, e.g. "20222-20223"
        self.ports_s = ports_s
        # expanded IP list (filled by parse)
        self.ips = []
        # expanded port list (filled by parse)
        self.ports = []
        # Target objects, one per (ip, port) pair (filled by parse)
        self.targets = []

    def parse(self):
        """Expand the raw specs and build the target list.

        Results are stored on ``self.ips``, ``self.ports`` and ``self.targets``.
        Calling this again rebuilds them instead of appending duplicates.

        :return: the list of Target objects.
        :raises Exception: if either spec is invalid; attributes are left
            untouched in that case.
        """
        # parse both specs before touching self, so a failure leaves no
        # half-updated state
        ips = parse_ips(self.ips_s)
        ports = parse_ports(self.ports_s)
        self.ips = ips
        self.ports = ports
        self.targets = [Target(ip, port) for ip in ips for port in ports]
        return self.targets

    def __str__(self):
        return 'Targets: %d ips, %d ports, %d targets' % (len(self.ips), len(self.ports), len(self.targets))
class ConfigLine(object):
    """One line of the forward config: ``<local_ports> <targets> [ssh_server]``.

    * ``local_ports`` is a port or inclusive port range (see ``parse_ports``).
    * ``targets`` is ``<ip pattern>:<port spec>`` (see ``Targets``).
    * ``ssh_server`` is ``[user[:password]@]ip[:port]`` and defaults to
      ``127.0.0.1``. Missing user, password and port fall back to the
      ``global_ssh_*`` settings.

    The number of local ports must equal the number of targets; they are
    paired positionally by the caller.
    """

    # Local ports claimed so far. This is deliberately a class attribute,
    # shared by every ConfigLine, so a port reused on a *different* line is
    # detected as a duplicate.
    unique_local_ports = set()

    def __init__(self, line):
        self.line = line
        self.local_ports = []
        self.targets = []
        self.ssh_server = None
        # raw whitespace-separated fields, filled in by parse()
        self.local_ports_s = None
        self.targets_s = None
        self.ssh_server_s = None

    def parse(self):
        """Split the line into fields and parse each of them.

        :raises Exception: on a wrong field count, an invalid field, a
            duplicate local port, or a local-port/target count mismatch.
        """
        parts = self.line.split()
        if len(parts) > 3:
            raise Exception("found %d part in the line, expected most 3" % len(parts))
        if len(parts) < 2:
            raise Exception("found %d part in the line, expected at least 2" % len(parts))

        self.local_ports_s = parts[0]
        self.targets_s = parts[1]
        # the ssh server (aka. proxy server) is optional
        self.ssh_server_s = parts[2] if len(parts) == 3 else '127.0.0.1'

        self.parse_local_ports()
        self.parse_targets()
        if len(self.local_ports) != len(self.targets):
            raise Exception("local ports num %d != targets num %d" % (len(self.local_ports), len(self.targets)))
        self.parse_ssh_server()

    def parse_local_ports(self):
        """Expand the local-port spec and claim the ports in the shared set.

        :raises Exception: if the spec is invalid or any port was already
            claimed by an earlier line.
        """
        local_ports = parse_ports(self.local_ports_s)
        for lp in local_ports:
            if lp in self.unique_local_ports:
                raise Exception("duplicate local port detect: %d" % lp)
        # Register only after the whole range is known to be free, so a
        # rejected line does not leave some of its ports claimed.
        self.unique_local_ports.update(local_ports)
        self.local_ports = local_ports

    def parse_targets(self):
        """Parse the ``<ip pattern>:<port spec>`` field into Target objects."""
        ips_s, _, ports_s = self.targets_s.partition(":")
        if not ips_s:
            raise Exception("target ip must be provide")
        if not ports_s:
            raise Exception("target port must be provide")
        ts = Targets(ips_s, ports_s)
        ts.parse()
        self.targets = ts.targets

    def parse_ssh_server(self):
        """Parse ``[user[:password]@]ip[:port]`` into an SSHServer.

        The split is on the *last* ``@``, so the password may contain ``@``.
        The split between user and password is on the *first* ``:``, so the
        password may contain ``:`` but the user may not.
        """
        user_pass_s, _, ip_port_s = self.ssh_server_s.rpartition('@')
        user, _, password = user_pass_s.partition(':')
        if not user:
            user = global_ssh_user
        if not password:
            password = global_ssh_password
        ip_s, _, port_s = ip_port_s.partition(':')
        ip = parse_ip(ip_s)
        if port_s:
            port = parse_port(port_s)
        else:
            port = global_ssh_port
        self.ssh_server = SSHServer(user, password, ip, port, global_ssh_key_path)

    def __str__(self):
        return 'Line: %s' % self.line
def main(config_text=None):
    """Parse forward-config text and print every resulting forward.

    Blank lines and lines starting with ``#`` are skipped. Each remaining line
    is parsed by ``ConfigLine``, and its local ports are paired positionally
    with its targets.

    :param config_text: the config to parse. When None, the module-level
        ``lines`` variable (set by the ``__main__`` block) is used, so a plain
        ``main()`` call keeps working.
    :return: list of ForwardConfig objects, in config order.
    :raises Exception: on the first invalid line; the message includes the
        1-based line number and the line content.
    """
    if config_text is None:
        config_text = lines
    # unique_local_ports is shared class state. Start each run with a clean
    # slate so calling main() twice does not report false duplicates.
    ConfigLine.unique_local_ports.clear()

    forwards = []
    for line_no, line in enumerate(config_text.splitlines(), 1):
        line = line.strip()
        if not line or line.startswith('#'):  # skip empty and comment lines
            continue
        try:
            config = ConfigLine(line)
            config.parse()
        except Exception as e:
            raise Exception("invalid line #%d: %s (line content: %s)" % (line_no, e, line))
        for local_port, target in zip(config.local_ports, config.targets):
            forwards.append(ForwardConfig(local_port, target, config.ssh_server))

    # debug output; print(x) with a single argument behaves the same on
    # Python 2 and 3
    for fwd in forwards:
        print(fwd)
    return forwards
def rotate(array):
    """Unfinished helper; only the trivial cases are implemented.

    Sequences of length 0 or 1 are returned unchanged, since rotating them is
    a no-op. The general case was never written: previously the function fell
    off the end and silently returned None. It now raises instead of handing
    back a value that looks like a result.

    NOTE(review): the intended rotation direction and amount are not visible
    in this file; implement once the use case is known.

    :raises NotImplementedError: for inputs longer than one element.
    """
    if len(array) <= 1:
        return array
    raise NotImplementedError("rotate() is only implemented for len(array) <= 1")
if __name__ == '__main__':
    # Sample config. The last line intentionally reuses local port 88 to
    # demonstrate the duplicate-port check.
    lines = '''\
80 10.0.81.88:8080
81-88 10.0.81.89:81-88
# comment line, 88 is duplicated with above line
88-153 127.0.0.1-33:20222-20223 10.0.81.88
'''
    try:
        main()
    except Exception as e:
        # ``Exception.message`` only exists on Python 2; str(e) works on both.
        # sys.exit() with a string argument prints it to stderr and exits
        # with status 1.
        sys.exit(str(e))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment