#!/usr/bin/env python3
# pylint: disable=C0103,C0302,W0612,W0621


"""It's the route manager by CIDR lists.
"""

import base64
import inspect
import ipaddress
import json
import logging
import urllib.request
from argparse import ArgumentParser
from datetime import datetime
from os import path, sep, makedirs, remove, rmdir, system, walk
from shutil import copyfile
from sys import platform
from subprocess import Popen, PIPE
from zipfile import ZipFile, BadZipFile


class Parse:
    """Parser of configs, arguments, parameters.
    """
    def __init__(self, parameters, block: str = None) -> None:
        """Object constructor.

        Args:
            parameters: dictionary as "key":"value" or
                        ArgumentParser class object or
                        string path to the file or
                        string as "var1=val1;var2=val2".
            block (str, optional): name of target block from text. Defaults to None.
        """
        self.path = ''
        self.data = {}
        if isinstance(parameters, dict):
            self._dict2dict(parameters)
        elif isinstance(parameters, ArgumentParser):
            self._dict2dict(self.argv2dict(parameters))
        elif isinstance(parameters, str):
            if path.exists(parameters):
                # The string names an existing file: parse the file content.
                self._dict2dict(
                    self.strs2dict(self.conf2strs(parameters), block)
                )
                self.path = parameters
            else:
                # Otherwise the string itself is the "var=val" text.
                self._dict2dict(self.strs2dict(parameters, block))

    def __str__(self) -> str:
        """Overrides method for print(object).

        Returns:
            str: string with contents of the object's dictionary.
        """
        return ''.join(
            str(type(val)) + ' ' + str(key) + ' = ' + str(val) + '\n'
            for key, val in self.data.items()
        )

    def _dict2dict(self, dictionary: dict) -> None:
        """Updates or adds dictionary data.

        Args:
            dictionary (dict): dictionary as "key":"value".
        """
        self.data.update(dictionary)

    # pylint: disable=C0206
    def expand(self, store: str = None) -> dict:
        """Expand dictionary "key":"name.conf" to dictionary "key":{subkey: subval}.

        Args:
            store (str, optional): path to directory with name.conf. Defaults to None.

        Returns:
            dict: expanded dictionary as "key":{subkey: subval}.
        """
        for key in self.data:
            if store:
                config = store + sep + self.data[key]
            else:
                config = self.data[key]
            with open(config, encoding='UTF-8') as file:
                self.data[key] = Parse(file.read()).data
        return self.data

    @classmethod
    def argv2dict(cls, parser: ArgumentParser) -> dict:
        """Converts startup arguments to a dictionary.

        Args:
            parser (ArgumentParser): argparse.ArgumentParser class object.

        Returns:
            dict: dictionary as "key":"value".
        """
        parser = ArgumentParser(add_help=False, parents=[parser])
        return vars(parser.parse_args())

    @classmethod
    def conf2strs(cls, config: str) -> str:
        """Reads a config file and drops its comment lines.

        Args:
            config (str): path to the config file.

        Returns:
            str: string as "var1=val1;\nvar2=val2;".
        """
        with open(config, encoding='UTF-8') as file:
            raw = file.read()
        return ''.join(
            line + '\n'
            for line in raw.splitlines()
            if not line.lstrip().startswith('#')
        )

    @classmethod
    def strs2dict(cls, strings: str, blockname: str) -> dict:
        """Builds a dictionary from a strings containing parameters.

        Args:
            strings (str): string as "var1=val1;var2=val2;".
            blockname (str): name of target block from text.

        Returns:
            dict: dictionary as "key":"value".
        """
        dictionary = {}
        if blockname:
            strings = cls.block(blockname, strings)
        for line in strings.replace('\n', ';').split(';'):
            if line.lstrip().startswith('#') or '=' not in line:
                continue
            # Split on the FIRST '=' only, so values that contain '='
            # themselves (tokens, passwords) are kept whole.
            key, _, value = line.partition('=')
            dictionary[key.strip()] = value.strip()
        return dictionary

    @classmethod
    def str2bool(cls, value: str) -> bool:
        """Converts a string value to boolean.

        Args:
            value (str): string containing "true" or "false", "yes" or "no", "1" or "0".

        Returns:
            bool: bool True or False.
        """
        return str(value).lower() in ("true", "yes", "1")

    @classmethod
    def block(cls, blockname: str, text: str) -> str:
        """Cuts a block of text between line [blockname] and line [next block] or EOF.

        Args:
            blockname (str): string in [] after which the block starts.
            text (str): string of text from which the block is needed.

        Returns:
            str: string of text between line [block name] and line [next block].
        """
        level = 1
        save = False
        result = ''
        for line in text.splitlines():
            if line.startswith('[') and blockname in line:
                # Header of the wanted block: remember its bracket depth.
                level = line.count('[')
                save = True
            elif line.startswith('[') and '['*level in line:
                # Any following header of at least that depth closes the block.
                save = False
            elif save:
                result += line + '\n'
        return result


class Connect:
    """Set of connection methods (functions) for various protocols.
    """
    @staticmethod
    # pylint: disable=W0102, W0718
    def http(
        url: str,
        method: str = 'GET',
        username: str = '',
        password: str = '',
        authtype: (str, type(None)) = None,
        contenttype: str = 'text/plain',
        contentdata: (str, bytes) = '',
        headers: dict = {},
        logger_alias: str = inspect.stack()[0].function
    ) -> dict:
        """Handling HTTP request.

        Args:
            url (str): Handling HTTP request.
            method (str, optional): HTTP request method. Defaults to 'GET'.
            username (str, optional): username for url authentication. Defaults to ''.
            password (str, optional): password for url authentication. Defaults to ''.
            authtype (str, None, optional): digest|basic authentication type. Defaults to None.
            contenttype (str, optional): 'Content-Type' header. Defaults to 'text/plain'.
            contentdata (str, bytes, optional): content data. Defaults to ''.
            headers (dict, optional): additional headers. Defaults to {}.
                The passed dictionary is copied and never modified.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Raises:
            TypeError: an argument does not match its annotation.

        Returns:
            dict: {'success':bool,'result':HTTP response or 'ERROR'}.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), Connect.http.__annotations__):
            # Work on a copy: mutating the argument would leak headers into the
            # caller's dict and, worse, into the shared default between calls.
            headers = dict(headers)
            if contentdata != '':
                headers['Content-Type'] = contenttype
            if isinstance(contentdata, str):
                contentdata = contentdata.encode('utf-8')

            # Preparing authorization
            if authtype:
                pswd = urllib.request.HTTPPasswordMgrWithDefaultRealm()
                pswd.add_password(None, url, username, password)
                if authtype == 'basic':
                    auth = urllib.request.HTTPBasicAuthHandler(pswd)
                    token = base64.b64encode((username + ':' + password).encode())
                    headers['Authorization'] = 'Basic ' + token.decode('utf-8')
                elif authtype == 'digest':
                    auth = urllib.request.HTTPDigestAuthHandler(pswd)
                else:
                    local_logger.debug(
                        msg='error: ' + '\n' + 'unsupported authtype: ' + authtype
                    )
                    return {"success": False, "result": "ERROR"}
                # NOTE: install_opener() changes the process-wide default opener.
                urllib.request.install_opener(urllib.request.build_opener(auth))

            # Preparing request
            request = urllib.request.Request(
                url=url,
                data=contentdata,
                method=method
            )
            for key, val in headers.items():
                request.add_header(key, val)
            # Only the logged representation is shortened; the request keeps all data.
            if len(contentdata) > 128:
                contentdata = contentdata[:64] + b' ... ' + contentdata[-64:]

            # Credentials are masked so they never end up in log files.
            logged_headers = {
                key: ('***' if str(key).lower() == 'authorization' else val)
                for key, val in headers.items()
            }

            # Response
            local_logger.debug(msg=''
                + '\n' + 'uri: ' + url
                + '\n' + 'method: ' + method
                + '\n' + 'username: ' + username
                + '\n' + 'password: ' + ('***' if password else '')
                + '\n' + 'authtype: ' + str(authtype)
                + '\n' + 'headers: ' + json.dumps(logged_headers, indent=2)
                + '\n' + 'content-data: ' + str(contentdata)
            )
            try:
                response = urllib.request.urlopen(request).read()
                try:
                    response = str(response.decode('utf-8'))
                except UnicodeDecodeError:
                    # Binary payload (e.g. zip archive): hand the bytes back as is.
                    pass
                return {"success": True, "result": response}
            except Exception as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                return {"success": False, "result": "ERROR"}


class Route(Connect):
    """Handling route operations.
    """
    def __init__(self, gateways: dict, db_root_path: str) -> None:
        """Object constructor.

        Args:
            gateways (dict): dictionary as "gateway name":[list of CIDR names].
            db_root_path (str): database directory.
        """
        self._gw = gateways
        self._db_root_path = db_root_path

    def do(
        self,
        action: str,
        imitate: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> None:
        """Add or delete route.

        Args:
            action (str): 'add' or 'delete'
            imitate (bool, optional): Only showing and counting commands without applying them.
                Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Raises:
            TypeError: an argument does not match its annotation.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.do.__annotations__):
            cidr_current = []
            for root, dirs, files in walk(self._db_root_path + sep + 'current', topdown=False):
                for file in files:
                    cidr_current.append(path.join(path.realpath(root), file))

            route_counter = 0
            apply_counter = 0
            gways_counter = 0
            files_counter = 0
            commands_list = []
            for gw, cidr_apply in self._gw.items():
                gways_counter += 1
                # Gateway key looks like "<type>-<name>"; both parts go to 'ip ro'.
                gw_parts = gw.split('-')
                if len(gw_parts) < 2:
                    local_logger.warning(msg=gw + ' is not in "type-name" format, skipped')
                    continue
                gw_type = gw_parts[0]
                gw_name = gw_parts[1]
                for cidr in cidr_apply:
                    # NOTE(review): the match is done against every file under
                    # 'current', not only the ones of this gateway - confirm intent.
                    for cidr_file in cidr_current:
                        if cidr not in cidr_file:
                            continue
                        with open(cidr_file, mode='r', encoding='utf-8') as file:
                            cidr_data = file.read()
                        files_counter += 1
                        for route in cidr_data.splitlines():
                            route = route.strip()
                            if not route:
                                continue
                            route_counter += 1
                            # No route command is built for win32.
                            if not platform.startswith('win32'):
                                commands_list.append({
                                    'cidr': cidr,
                                    'command': ['ip', 'ro', action, route, gw_type, gw_name]
                                })

            if action == 'delete':
                # Routes are removed in the opposite order they were added.
                commands_list.reverse()

            for command in commands_list:
                cidr_logger = logging.getLogger(command['cidr'])
                cidr_logger.info(msg=' '.join(command['command']))
                if not imitate:
                    if self.__cmd(command=command['command']) == 0:
                        apply_counter += 1

            local_logger.info(msg=""
                + action + " " + str(apply_counter) + " route(s)"
                + " for " + str(gways_counter) + " gateway(s)"
                + " from " + str(files_counter) + " file(s)"
                + " with " + str(route_counter) + " route(s)"
            )

    def __cmd(self, command: list, logger_alias: str = inspect.stack()[0].function) -> int:
        """Run a command and forward its output to the logger.

        Args:
            command (list): command and its arguments.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            int: return code of the finished process.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__cmd.__annotations__):
            with Popen(command, stdout=PIPE, stderr=PIPE) as proc:
                # communicate() drains both pipes at once, so a chatty stderr
                # cannot block the child while stdout is still being read.
                out, err = proc.communicate()
            for line in out.splitlines(keepends=True):
                local_logger.info(msg=line.decode('utf-8'))
            for line in err.splitlines(keepends=True):
                local_logger.warning(msg=line.decode('utf-8'))
            return proc.returncode

    def update_db_current(
        self,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Update current CIDR file database from sources.

        Args:
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - database updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.update_db_current.__annotations__):
            sources_root = path.realpath(self._db_root_path + sep + 'sources')
            current_root = path.realpath(self._db_root_path + sep + 'current')

            cidr_sources = []
            for root, dirs, files in walk(self._db_root_path + sep + 'sources', topdown=False):
                for file in files:
                    cidr_sources.append(path.join(path.realpath(root), file))

            # Wipe the previous 'current' tree bottom-up.
            try:
                for root, dirs, files in walk(self._db_root_path + sep + 'current', topdown=False):
                    for file in files:
                        remove(path.join(path.realpath(root), file))
                    for directory in dirs:
                        rmdir(path.join(path.realpath(root), directory))
            except OSError as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                return False

            for gw, cidr_current in self._gw.items():
                for cidr in cidr_current:
                    cidr_saved = False
                    for src_file in cidr_sources:
                        if cidr not in src_file:
                            continue
                        # Mirror the file below current/<gateway>/ keeping its
                        # position relative to the sources root.
                        dst_file = path.join(
                            current_root, gw, path.relpath(src_file, sources_root)
                        )
                        try:
                            makedirs(path.dirname(dst_file), exist_ok=True)
                            copyfile(src=src_file, dst=dst_file)
                            local_logger.info(msg=dst_file + ' saved')
                            cidr_saved = True
                            break
                        except IOError as error:
                            local_logger.debug(msg='error: ' + '\n' + str(error))
                    if not cidr_saved:
                        local_logger.warning(msg=cidr + ' not saved')
            return True

    def update_db_sources(
        self,
        name: str,
        db_root_path: str,
        db_source_code: (str, type(None)) = None,
        download_token: (str, type(None)) = None,
        download_user: (str, type(None)) = None,
        download_pass: (str, type(None)) = None,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download and extract sources to CIDR file database.

        Args:
            name (str): 'google', 'amazon', 'atlassian', 'herrbischoff', 'ip2location',
                'githmptoday'.
            db_root_path (str): database directory.
            db_source_code (str, None, optional): ip2location database code. Defaults to None.
            download_token (str, None, optional): ip2location download token. Defaults to None.
            download_user (str, None, optional): githmptoday user name. Defaults to None.
            download_pass (str, None, optional): githmptoday password. Defaults to None.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions or the name is unknown.
        """
        if Do.args_valid(locals(), self.update_db_sources.__annotations__):
            if name == '':
                pass
            elif name == 'google':
                return self.__update_source_google(
                    db_root_path=db_root_path,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            elif name == 'amazon':
                return self.__update_source_amazon(
                    db_root_path=db_root_path,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            elif name == 'atlassian':
                return self.__update_source_atlassian(
                    db_root_path=db_root_path,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            elif name == 'herrbischoff':
                return self.__update_source_herrbischoff(
                    db_root_path=db_root_path,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            elif name == 'ip2location':
                return self.__update_source_ip2location(
                    db_root_path=db_root_path,
                    db_source_code=db_source_code,
                    download_token=download_token,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            elif name == 'githmptoday':
                return self.__update_source_githmptoday(
                    db_root_path=db_root_path,
                    download_user=download_user,
                    download_pass=download_pass,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
        # Empty or unknown source name: nothing was updated.
        return False

    def __download_db(
        self,
        url: str,
        dst: str,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download a source file and save it to disk.

        Args:
            url (str): address of the source file.
            dst (str): path of the file to write.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - file saved, False - downloading or saving failed.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__download_db.__annotations__):
            try:
                makedirs(path.dirname(dst), exist_ok=True)

                response = self.http(url=url, method='GET')
                if not response['success']:
                    raise ConnectionError('downloading ' + url + ' failed')

                # Text was decoded by http(); anything else arrives as raw bytes.
                open_mode = 'w+'
                open_encoding = 'utf-8'
                if isinstance(response['result'], bytes):
                    open_mode = 'wb+'
                    open_encoding = None
                with open(dst, mode=open_mode, encoding=open_encoding) as file:
                    file.write(response['result'])
                local_logger.info(msg=dst + ' saved')
                return True
            except OSError as error:
                # ConnectionError is an OSError too; disk errors are reported alike.
                local_logger.warning(msg='' + str(error))
                return False

    def __fetch_source(
        self,
        url: str,
        dst: str,
        force_download: bool,
        logger_alias: str
    ) -> bool:
        """Make sure the raw source file exists, downloading it when missing or forced."""
        if force_download or not path.exists(dst):
            return self.__download_db(url=url, dst=dst, logger_alias=logger_alias)
        return True

    def __save_cidr_list(
        self,
        cidr_root: str,
        ip_version: str,
        name: str,
        cidr_list: list,
        local_logger: logging.Logger
    ) -> None:
        """Write one CIDR per line to <cidr_root>/<ip_version>/<name>.cidr."""
        makedirs(cidr_root + sep + ip_version, exist_ok=True)
        cidr_file = ("" + cidr_root + sep + ip_version + sep + name + ".cidr")
        with open(cidr_file, mode='w+', encoding='utf-8') as cidr_dump:
            cidr_dump.write('\n'.join(cidr_list))
        local_logger.info(msg=cidr_file + ' saved')

    def __update_source_google(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download and convert the google ranges to CIDR files.

        Args:
            db_root_path (str): database directory.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_google.__annotations__):
            db_source_url = "https://www.gstatic.com/ipranges/goog.json"
            db_source_name = "google"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "goog.json"
            db_source_cidr_root = db_source_root + sep + "cidr"

            if not self.__fetch_source(
                db_source_url, db_source_file, force_download, logger_alias
            ):
                return False

            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                db_source_data = json.loads(db_source_raw.read())

            db_parsed_data_ipv4 = []
            db_parsed_data_ipv6 = []
            for item in db_source_data['prefixes']:
                if 'ipv4Prefix' in item:
                    db_parsed_data_ipv4.append(item['ipv4Prefix'])
                if 'ipv6Prefix' in item:
                    db_parsed_data_ipv6.append(item['ipv6Prefix'])

            self.__save_cidr_list(
                db_source_cidr_root, "ipv4", db_source_name, db_parsed_data_ipv4, local_logger
            )
            self.__save_cidr_list(
                db_source_cidr_root, "ipv6", db_source_name, db_parsed_data_ipv6, local_logger
            )
            return True
        return False

    def __update_source_amazon(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download and convert the amazon ranges to CIDR files.

        Args:
            db_root_path (str): database directory.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_amazon.__annotations__):
            db_source_url = "https://ip-ranges.amazonaws.com/ip-ranges.json"
            db_source_name = "amazon"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "ip-ranges.json"
            db_source_cidr_root = db_source_root + sep + "cidr"

            if not self.__fetch_source(
                db_source_url, db_source_file, force_download, logger_alias
            ):
                return False

            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                db_source_data = json.loads(db_source_raw.read())

            db_parsed_data_ipv4 = [
                item['ip_prefix'] for item in db_source_data['prefixes']
            ]
            self.__save_cidr_list(
                db_source_cidr_root, "ipv4", db_source_name, db_parsed_data_ipv4, local_logger
            )

            db_parsed_data_ipv6 = [
                item['ipv6_prefix'] for item in db_source_data['ipv6_prefixes']
            ]
            self.__save_cidr_list(
                db_source_cidr_root, "ipv6", db_source_name, db_parsed_data_ipv6, local_logger
            )
            return True
        return False

    def __update_source_atlassian(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download and convert the atlassian ranges to CIDR files.

        Args:
            db_root_path (str): database directory.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_atlassian.__annotations__):
            db_source_url = "https://ip-ranges.atlassian.com"
            db_source_name = "atlassian"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "ip-ranges.json"
            db_source_cidr_root = db_source_root + sep + "cidr"

            if not self.__fetch_source(
                db_source_url, db_source_file, force_download, logger_alias
            ):
                return False

            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                db_source_data = json.loads(db_source_raw.read())

            db_parsed_data_ipv4 = []
            db_parsed_data_ipv6 = []
            for item in db_source_data['items']:
                # The feed mixes both families; a colon marks an IPv6 entry.
                if ":" not in item['cidr']:
                    db_parsed_data_ipv4.append(item['cidr'])
                else:
                    db_parsed_data_ipv6.append(item['cidr'])

            self.__save_cidr_list(
                db_source_cidr_root, "ipv4", db_source_name, db_parsed_data_ipv4, local_logger
            )
            self.__save_cidr_list(
                db_source_cidr_root, "ipv6", db_source_name, db_parsed_data_ipv6, local_logger
            )
            return True
        return False

    def __update_source_herrbischoff(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download the herrbischoff archive and extract its country CIDR files.

        Args:
            db_root_path (str): database directory.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_herrbischoff.__annotations__):
            db_source_url = (""
                + "https://github.com/herrbischoff/"
                + "country-ip-blocks/archive/refs/heads/master.zip"
            )
            db_source_name = "herrbischoff"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "country-ip-blocks-master.zip"
            db_source_cidr_root = db_source_root + sep + "cidr"

            if not self.__fetch_source(
                db_source_url, db_source_file, force_download, logger_alias
            ):
                return False

            try:
                with ZipFile(db_source_file, mode='r') as db_source_file_zip:
                    makedirs(db_source_cidr_root + sep + "ipv4", exist_ok=True)
                    makedirs(db_source_cidr_root + sep + "ipv6", exist_ok=True)
                    for file in db_source_file_zip.infolist():
                        if file.is_dir():
                            continue
                        for ip_version in ("ipv4", "ipv6"):
                            # Member names inside a zip always use '/',
                            # whatever the separator of the host OS is.
                            if "country-ip-blocks-master/" + ip_version not in file.filename:
                                continue
                            country_data = db_source_file_zip.read(file.filename)
                            country_file = (""
                                + db_source_cidr_root + sep
                                + ip_version + sep + path.basename(file.filename)
                            )
                            with open(country_file, mode='wb') as country_dump:
                                country_dump.write(country_data)
                            local_logger.info(msg=country_file + ' saved')
                    return True
            except BadZipFile as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                local_logger.warning(msg=db_source_file + ' corrupted and deleted')
                remove(db_source_file)
        return False

    def __update_source_ip2location(
        self,
        db_root_path: str,
        db_source_code: str,
        download_token: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download an ip2location database and convert it to country CIDR files.

        Only the 'DB1LITECSV', 'DB1' and 'DB1CIDR' codes are converted.

        Args:
            db_root_path (str): database directory.
            db_source_code (str): ip2location database code.
            download_token (str): ip2location download token.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_ip2location.__annotations__):
            # Database code -> name of the CSV member inside the downloaded zip.
            DB_IP2L_BASE = {
                "DB1LITECSV": "IP2LOCATION-LITE-DB1.CSV",
                "DB1LITECSVIPV6": "IP2LOCATION-LITE-DB1.IPV6.CSV",
                "DB1": "IPCountry.csv",
                "DB1IPV6": "IPV6-COUNTRY.CSV",
                "DB1CIDR": "IP2LOCATION-IP-COUNTRY.CIDR.CSV",
                "DB1CIDRIPV6": "IP2LOCATION-IPV6-COUNTRY.CIDR.CSV"
            }
            db_source_url = (''
                + "https://www.ip2location.com/download?token=" + download_token
                + "&file=" + db_source_code
            )
            db_source_name = "ip2location"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + db_source_code + ".ZIP"
            db_source_cidr_root = db_source_root + sep + "cidr"

            if not self.__fetch_source(
                db_source_url, db_source_file, force_download, logger_alias
            ):
                return False

            if db_source_code in ('DB1LITECSV', 'DB1'):
                rows_are_ranges = True
            elif db_source_code == 'DB1CIDR':
                rows_are_ranges = False
            else:
                local_logger.warning(msg=db_source_code + ' conversion is not supported')
                return False

            try:
                with ZipFile(db_source_file, mode='r') as db_source_file_zip:
                    with db_source_file_zip.open(
                        DB_IP2L_BASE[db_source_code], mode='r'
                    ) as db_source_raw:
                        db_source_data = db_source_raw.read().decode('utf-8')

                db_parsed_data = {}
                for line in db_source_data.splitlines():
                    if not line.strip():
                        continue
                    fields = [field.replace('"', '') for field in line.split(',')]
                    if rows_are_ranges:
                        # Row: first ip number, last ip number, country code, ...
                        f_ipadr = ipaddress.IPv4Address(int(fields[0]))
                        l_ipadr = ipaddress.IPv4Address(int(fields[1]))
                        country_code = fields[2].lower()
                        country_cidrs = [
                            cidr.exploded
                            for cidr in ipaddress.summarize_address_range(f_ipadr, l_ipadr)
                        ]
                    else:
                        # Row: cidr, country code, ...
                        country_code = fields[1].lower()
                        country_cidrs = [fields[0]]
                    if country_cidrs:
                        db_parsed_data.setdefault(country_code, []).extend(country_cidrs)

                makedirs(db_source_cidr_root + sep + "ipv4", exist_ok=True)
                for country_code, country_data in db_parsed_data.items():
                    self.__save_cidr_list(
                        db_source_cidr_root, "ipv4", country_code, country_data, local_logger
                    )
                return True
            except BadZipFile as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                local_logger.warning(msg=db_source_file + ' corrupted and deleted')
                remove(db_source_file)
        return False

    def __update_source_githmptoday(
        self,
        db_root_path: str,
        download_user: str,
        download_pass: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Clone the prepared database from git.hmp.today into the database directory.

        Args:
            db_root_path (str): database directory.
            download_user (str): git user name.
            download_pass (str): git password.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the name captured when the method was defined.

        Returns:
            bool: True - sources updated, False - nothing downloaded or there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_githmptoday.__annotations__):
            if not path.exists(db_root_path + sep + 'sources'):
                force_download = True

            if force_download:
                db_temp_path = path.dirname(db_root_path) + sep + 'tmp'
                # Argument list instead of a shell string: credentials and paths
                # are passed verbatim and cannot be interpreted by a shell.
                cmd_gitclone = [
                    "git", "clone",
                    "https://" + download_user + ":" + download_pass
                    + "@git.hmp.today/pavel.muhortov/my_route.db.git",
                    db_temp_path
                ]
                try:
                    with Popen(cmd_gitclone) as proc:
                        ret_gitclone = proc.wait()
                except OSError as error:
                    local_logger.debug(msg='error: ' + '\n' + str(error))
                    local_logger.warning(msg='git clone could not be started')
                    return False

                if ret_gitclone == 0:
                    try:
                        makedirs(db_root_path, exist_ok=True)
                        # Move the clone file by file, emptying the temp tree bottom-up.
                        for root, dirs, files in walk(db_temp_path, topdown=False):
                            for file in files:
                                src_file = path.join(path.realpath(root), file)
                                dst_file = src_file.replace(db_temp_path, db_root_path)
                                makedirs(path.dirname(dst_file), exist_ok=True)
                                copyfile(src=src_file, dst=dst_file)
                                remove(src_file)
                            for directory in dirs:
                                rmdir(path.join(path.realpath(root), directory))
                        rmdir(db_temp_path)
                        return True
                    except OSError as error:
                        local_logger.debug(msg='error: ' + '\n' + str(error))
                else:
                    local_logger.warning(msg=''
                        + 'git clone returned '+ str(ret_gitclone) + ' code. '
                        + 'Restart by interactive and check stdout.'
                    )
        return False


class Do():
    """Set of various methods (functions) for routine.
    """
    @staticmethod
    def args_valid(arguments: dict, annotations: dict) -> bool:
        """Arguments type validating by annotations.

        Args:
            arguments (dict): 'locals()' immediately after starting the function.
            annotations (dict): function.name.__annotations__.

        Raises:
            TypeError: type of argument is not equal type in annotation.

        Returns:
            bool: True if argument types are valid.
        """
        for var_name, var_type in annotations.items():
            if var_name == 'return':
                continue
            if not isinstance(arguments[var_name], var_type):
                raise TypeError(""
                    + "type of '"
                    + var_name
                    + "' = "
                    + str(arguments[var_name])
                    + " is not "
                    + str(var_type)
                )
        return True

    @staticmethod
    def checkroot() -> bool:
        # pylint: disable=C0415
        """Crossplatform privileged rights checker.

        Returns:
            bool: True - if privileged rights, False - if not privileged rights
                or the platform is not supported.
        """
        if platform.startswith('linux') or platform.startswith('darwin'):
            from os import geteuid
            return geteuid() == 0
        if platform.startswith('win32'):
            import ctypes
            return bool(ctypes.windll.shell32.IsUserAnAdmin())
        return False


if __name__ == "__main__":
    time_start = datetime.now()

    args = ArgumentParser(
        prog='my-route',
        description='Route management by CIDR lists.',
        epilog='Dependencies: '
               '- Python 3 (tested version 3.9.5), '
               '- privileged rights, '
               '- git '
    )
    args.add_argument('--config', type=str,
                      default=path.splitext(__file__)[0] + '.conf',
                      required=False,
                      help='custom configuration file path'
    )
    args.add_argument('-a', '--add', action='store_true', required=False,
                      help='add routes specified by config')
    args.add_argument('-d', '--del', action='store_true', required=False,
                      help='del routes specified by config')
    args.add_argument('-i', '--imitate', action='store_true', required=False,
                      help='only showing commands without applying them')
    args.add_argument('-u', '--update', action='store_true', required=False,
                      help='update cidr file db')
    args.add_argument('-f', '--force', action='store_true', required=False,
                      help='force download sources for update')
    args = vars(args.parse_args())

    # Defaults, used when the config file is missing or does not set them.
    db_root_path = (''
        + path.dirname(path.realpath(__file__)) + sep
        + path.splitext(path.basename(__file__))[0] + '.db')
    log_level = 'INFO'
    log_root = path.dirname(path.realpath(__file__))
    enable_gateway = {}
    enable_sources = {}

    if path.exists(args['config']):
        conf_common = Parse(parameters=args['config'], block='common')
        if 'db_root_path' in conf_common.data:
            db_root_path = conf_common.data['db_root_path']
        if 'log_root' in conf_common.data:
            log_root = conf_common.data['log_root']
        # Unknown level names are ignored and the default stays in effect.
        LOG_LEVELS = {
            'DEBUG': logging.DEBUG,
            'INFO': logging.INFO,
            'WARNING': logging.WARNING,
            'ERROR': logging.ERROR,
            'CRITICAL': logging.CRITICAL
        }
        log_level = LOG_LEVELS.get(conf_common.data.get('log_level'), log_level)

        # Every enabled gateway has its own block listing the CIDR names to apply.
        conf_gateway = Parse(parameters=args['config'], block='enable-gateway')
        for key, value in conf_gateway.data.items():
            if value == 'true':
                gateway_config = Parse(
                    parameters=args['config'],
                    block=key
                )
                enable_gateway[key] = []
                for cidr, enable in gateway_config.data.items():
                    if enable == 'true':
                        enable_gateway[key].append(cidr)

        conf_sources = Parse(parameters=args['config'], block='enable-sources')
        for key, value in conf_sources.data.items():
            if value == 'true':
                enable_sources[key] = {
                    'enable': value,
                    'download_token': None,
                    'db_source_code': None,
                    'download_user': None,
                    'download_pass': None
                }
        if 'ip2location' in enable_sources:
            enable_sources['ip2location']['download_token'] = (
                conf_sources.data['ip2l_download_token']
            )
            enable_sources['ip2location']['db_source_code'] = (
                conf_sources.data['ip2l_database_code']
            )
        if 'githmptoday' in enable_sources:
            enable_sources['githmptoday']['download_user'] = (
                conf_sources.data['githmptoday_user']
            )
            enable_sources['githmptoday']['download_pass'] = (
                conf_sources.data['githmptoday_pass']
            )

    logging.basicConfig(
        format='%(asctime)s %(levelname)s: %(name)s: %(message)s',
        datefmt='%Y-%m-%d_%H.%M.%S',
        handlers=[
            logging.FileHandler(
                filename=log_root + sep + path.splitext(path.basename(__file__))[0] + '.log',
                mode='a'
            ),
            logging.StreamHandler()
        ],
        level=log_level
    )

    if Do.checkroot():
        ro = Route(gateways=enable_gateway, db_root_path=db_root_path)

        if args['update']:
            for key, value in enable_sources.items():
                ro.update_db_sources(
                    name=key,
                    db_root_path=db_root_path,
                    db_source_code=value['db_source_code'],
                    download_token=value['download_token'],
                    download_user=value['download_user'],
                    download_pass=value['download_pass'],
                    force_download=args['force'],
                    logger_alias='update sources ' + key
                )
            ro.update_db_current(logger_alias='update current')

        elif args['add']:
            ro.do(action='add', imitate=args['imitate'])
        elif args['del']:
            ro.do(action='delete', imitate=args['imitate'])
        else:
            logging.info(msg='No start arguments selected. Exit.')
    else:
        logging.warning(msg='Restart this as root!')

    time_execute = datetime.now() - time_start
    logging.info(msg='execution time is ' + str(time_execute) + '. Exit.')