my_route/my_route.py

1147 lines
48 KiB
Python
Executable File

#!/usr/bin/env python3
# pylint: disable=C0103,C0302,W0612,W0621
"""It's the route manager by CIDR lists.
"""
import base64
import inspect
import ipaddress
import json
import logging
import urllib.request
from argparse import ArgumentParser
from datetime import datetime
from os import path, sep, makedirs, remove, rmdir, system, walk
from shutil import copyfile
from sys import platform
from subprocess import Popen, PIPE
from zipfile import ZipFile, BadZipFile
class Parse:
    """Parser of configs, arguments, parameters.
    """
    def __init__(self, parameters, block: str = None) -> None:
        """Object constructor.

        Args:
            parameters: dictionary as "key":"value" or
                        ArgumentParser class object or
                        string path to the file or
                        string as "var1=val1;var2=val2".
            block (str, optional): name of target block from text. Defaults to None.
        """
        self.path = ''
        self.data = {}
        if isinstance(parameters, dict):
            self._dict2dict(parameters)
        elif isinstance(parameters, ArgumentParser):
            self._dict2dict(self.argv2dict(parameters))
        elif isinstance(parameters, str):
            if path.exists(parameters):
                # An existing path is read as a config file (comment lines dropped).
                self._dict2dict(self.strs2dict(self.conf2strs(parameters), block))
                self.path = parameters
            else:
                # Otherwise the string itself holds "var1=val1;var2=val2".
                self._dict2dict(self.strs2dict(parameters, block))

    def __str__(self) -> str:
        """Overrides method for print(object).

        Returns:
            str: one "<type> key = value" line per entry of the object's dictionary.
        """
        return ''.join(
            str(type(val)) + ' ' + str(key) + ' = ' + str(val) + '\n'
            for key, val in self.data.items()
        )

    def _dict2dict(self, dictionary: dict) -> None:
        """Updates or adds dictionary data.

        Args:
            dictionary (dict): dictionary as "key":"value".
        """
        self.data.update(dictionary)

    # pylint: disable=C0206
    def expand(self, store: str = None) -> dict:
        """Expand dictionary "key":"name.conf" to dictionary "key":{subkey: subval}.

        Args:
            store (str, optional): path to directory with name.conf. Defaults to None.

        Returns:
            dict: expanded dictionary as "key":{subkey: subval}.
        """
        for key in self.data:
            if store:
                config = store + sep + self.data[key]
            else:
                config = self.data[key]
            with open(config, encoding='UTF-8') as file:
                self.data[key] = Parse(file.read()).data
        return self.data

    @classmethod
    def argv2dict(cls, parser: ArgumentParser) -> dict:
        """Converts startup arguments to a dictionary.

        Args:
            parser (ArgumentParser): argparse.ArgumentParser class object.

        Returns:
            dict: dictionary as "key":"value".
        """
        parser = ArgumentParser(add_help=False, parents=[parser])
        return vars(parser.parse_args())

    @classmethod
    def conf2strs(cls, config: str) -> str:
        """Reads a config file and drops the lines that are '#' comments.

        Args:
            config (str): path to the config file.

        Returns:
            str: remaining lines, each terminated by '\\n'.
        """
        with open(config, encoding='UTF-8') as file:
            raw = file.read()
        return ''.join(
            line + '\n'
            for line in raw.splitlines()
            if not line.lstrip().startswith('#')
        )

    @classmethod
    def strs2dict(cls, strings: str, blockname: str) -> dict:
        """Builds a dictionary from a strings containing parameters.

        Args:
            strings (str): string as "var1=val1;var2=val2;" (newlines also separate).
            blockname (str): name of target block from text; falsy means whole text.

        Returns:
            dict: dictionary as "key":"value" (keys and values stripped).
        """
        dictionary = {}
        if blockname:
            strings = cls.block(blockname, strings)
        for line in strings.replace('\n', ';').split(';'):
            if line.lstrip().startswith('#') or '=' not in line:
                continue
            # Split on the first '=' only: values may themselves contain '='
            # (e.g. base64-padded tokens or passwords).
            key, value = line.split('=', 1)
            dictionary[key.strip()] = value.strip()
        return dictionary

    @classmethod
    def str2bool(cls, value: str) -> bool:
        """Converts a string value to boolean.

        Args:
            value (str): string containing "true" or "false", "yes" or "no", "1" or "0".

        Returns:
            bool: True for "true"/"yes"/"1" (case-insensitive), otherwise False.
        """
        return str(value).lower() in ("true", "yes", "1")

    @classmethod
    def block(cls, blockname: str, text: str) -> str:
        """Cuts a block of text between line [blockname] and line [next block] or EOF.

        The header name must equal blockname exactly (surrounding brackets and
        whitespace ignored), so "[common2]" is not taken for "common".

        Args:
            blockname (str): string in [] after which the block starts.
            text (str): string of text from which the block is needed.

        Returns:
            str: string of text between line [block name] and line [next block].
        """
        level = 1
        save = False
        result = ''
        for line in text.splitlines():
            if line.startswith('['):
                header = line.strip().lstrip('[').split(']', 1)[0].strip()
                if header == blockname:
                    # Bracket depth of the matched header decides which
                    # following headers terminate the block.
                    level = line.count('[')
                    save = True
                    continue
                if '[' * level in line:
                    save = False
                    continue
            if save:
                result += line + '\n'
        return result
class Connect:
    """Set of connection methods (functions) for various protocols.
    """
    @staticmethod
    # pylint: disable=W0102, W0718
    def http(
        url: str,
        method: str = 'GET',
        username: str = '',
        password: str = '',
        authtype: (str, type(None)) = None,
        contenttype: str = 'text/plain',
        contentdata: (str, bytes) = '',
        headers: dict = {},
        logger_alias: str = inspect.stack()[0].function
    ) -> dict:
        """Handling HTTP request.

        Args:
            url (str): request URL.
            method (str, optional): HTTP request method. Defaults to 'GET'.
            username (str, optional): username for url authentication. Defaults to ''.
            password (str, optional): password for url authentication. Defaults to ''.
            authtype (str, None, optional): 'digest' or 'basic' authentication type.
                Defaults to None (no authentication).
            contenttype (str, optional): 'Content-Type' header, sent only with a
                non-empty body. Defaults to 'text/plain'.
            contentdata (str, bytes, optional): request body; str is UTF-8 encoded.
                Defaults to '' (no body).
            headers (dict, optional): additional headers; the dict is never modified.
                Defaults to {}.
            logger_alias (str, optional): sublogger name. Defaults to the enclosing
                class name (the default is evaluated once, when the class body runs).

        Returns:
            dict: {'success': bool, 'result': response body (str when it decodes
                as UTF-8, otherwise bytes) or 'ERROR'}.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), Connect.http.__annotations__):
            # Private copy: the {} default is shared between calls and the
            # caller's dict must not be changed either.
            headers = dict(headers)
            if isinstance(contentdata, str):
                contentdata = contentdata.encode('utf-8')
            if contentdata:
                headers['Content-Type'] = contenttype
            # Preparing authorization
            handlers = []
            if authtype:
                pswd = urllib.request.HTTPPasswordMgrWithDefaultRealm()
                pswd.add_password(None, url, username, password)
                if authtype == 'basic':
                    handlers.append(urllib.request.HTTPBasicAuthHandler(pswd))
                    token = base64.b64encode((username + ':' + password).encode())
                    headers['Authorization'] = 'Basic ' + token.decode('utf-8')
                elif authtype == 'digest':
                    handlers.append(urllib.request.HTTPDigestAuthHandler(pswd))
                else:
                    local_logger.warning(msg='unsupported authtype: ' + authtype)
                    return {"success": False, "result": "ERROR"}
            # Per-call opener, so credentials are not installed process-wide
            # for every later urlopen() call.
            opener = urllib.request.build_opener(*handlers)
            # Preparing request; urllib accepts bytes or None as data, never str.
            request = urllib.request.Request(
                url=url,
                data=contentdata or None,
                method=method
            )
            for key, val in headers.items():
                request.add_header(key, val)
            # Log the request with secrets masked and long bodies shortened.
            shown_data = contentdata
            if len(shown_data) > 128:
                shown_data = shown_data[:64] + b' ... ' + shown_data[-64:]
            shown_headers = {
                key: ('***' if key.lower() == 'authorization' else val)
                for key, val in headers.items()
            }
            local_logger.debug(msg=''
                + '\n' + 'uri: ' + url
                + '\n' + 'method: ' + method
                + '\n' + 'username: ' + username
                + '\n' + 'password: ' + ('***' if password else '')
                + '\n' + 'authtype: ' + str(authtype)
                + '\n' + 'headers: ' + json.dumps(shown_headers, indent=2)
                + '\n' + 'content-data: ' + str(shown_data)
            )
            # Response
            try:
                with opener.open(request) as response:
                    result = response.read()
            except Exception as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                return {"success": False, "result": "ERROR"}
            try:
                result = result.decode('utf-8')
            except UnicodeDecodeError:
                # Binary payloads (e.g. zip archives) are returned as bytes.
                pass
            return {"success": True, "result": result}
        return {"success": False, "result": "ERROR"}
class Route(Connect):
    """Handling route operations.

    The CIDR file database under db_root_path is laid out as:
        sources/<source>/...            downloaded sources and parsed .cidr files
        current/<gateway>/<source>/...  copies selected for each gateway
    """
    def __init__(self, gateways: dict, db_root_path: str) -> None:
        """Object constructor.

        Args:
            gateways (dict): {"<type>-<name>": [cidr, ...]}; <type> and <name> are
                passed to `ip route` as-is (presumably e.g. 'via' + address or
                'dev' + interface). Each cidr is a substring matched against
                CIDR file paths.
            db_root_path (str): CIDR file database directory.
        """
        self._gw = gateways
        self._db_root_path = db_root_path

    @staticmethod
    def __list_files(top: str) -> list:
        """Return the real paths of all files below top (empty if top is missing)."""
        return [
            path.join(path.realpath(root), file)
            for root, _dirs, files in walk(top, topdown=False)
            for file in files
        ]

    @staticmethod
    def __save_cidr(cidr_root: str, family: str, name: str, cidrs: list, local_logger) -> None:
        """Write cidrs, one per line, to <cidr_root>/<family>/<name>.cidr."""
        makedirs(cidr_root + sep + family, exist_ok=True)
        cidr_file = cidr_root + sep + family + sep + name + ".cidr"
        with open(cidr_file, mode='w+', encoding='utf-8') as cidr_dump:
            cidr_dump.write('\n'.join(cidrs))
        local_logger.info(msg=cidr_file + ' saved')

    def do(
        self,
        action: str,
        imitate: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> None:
        """Add or delete the routes of every configured gateway.

        Args:
            action (str): 'add' or 'delete', passed to `ip route`.
            imitate (bool, optional): Only showing and counting commands without applying them.
                Defaults to False.
            logger_alias (str, optional): sublogger name.
                Defaults to the enclosing class name (evaluated at class definition).
        """
        local_logger = logging.getLogger(logger_alias)
        # Validate action/imitate/logger_alias against this method's own annotations.
        if Do.args_valid(locals(), self.do.__annotations__):
            route_counter = 0
            apply_counter = 0
            gways_counter = 0
            files_counter = 0
            commands_list = []
            for gw, cidr_apply in self._gw.items():
                gways_counter += 1
                # Split on the first '-' only, so names such as 'dev-wg-0' stay whole.
                gw_type, _, gw_name = gw.partition('-')
                # Only this gateway's copies, as laid out by update_db_current().
                cidr_current = self.__list_files(
                    self._db_root_path + sep + 'current' + sep + gw
                )
                for cidr in cidr_apply:
                    for cidr_file in cidr_current:
                        if cidr not in cidr_file:
                            continue
                        with open(cidr_file, mode='r', encoding='utf-8') as file:
                            cidr_data = file.read()
                        files_counter += 1
                        for line in cidr_data.splitlines():
                            route = line.split('#')[0].strip()
                            if not route:
                                # Blank and comment-only lines are not routes.
                                continue
                            route_counter += 1
                            if platform.startswith('win32'):
                                # No route command is implemented for Windows.
                                continue
                            commands_list.append({
                                'cidr': cidr,
                                'command': ['ip', 'ro', action, route, gw_type, gw_name]
                            })
            if action == 'delete':
                # Delete in the reverse order of the generated command list.
                commands_list.reverse()
            for command in commands_list:
                logging.getLogger(command['cidr']).info(msg=' '.join(command['command']))
                if not imitate and self.__cmd(command=command['command']) == 0:
                    apply_counter += 1
            local_logger.info(msg=""
                + action + " " + str(apply_counter) + " route(s)"
                + " for " + str(gways_counter) + " gateway(s)"
                + " from " + str(files_counter) + " file(s)"
                + " with " + str(route_counter) + " route(s)"
            )

    def __cmd(self, command: list, logger_alias: str = inspect.stack()[0].function) -> int:
        """Run command, logging its stdout as info and its stderr as warning.

        Returns:
            int: the command's exit code.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__cmd.__annotations__):
            with Popen(command, stdout=PIPE, stderr=PIPE) as proc:
                # communicate() drains both pipes together (no deadlock on a full
                # stderr pipe) and waits, so returncode is set afterwards.
                stdout, stderr = proc.communicate()
            for line in stdout.decode('utf-8', errors='replace').splitlines():
                local_logger.info(msg=line)
            for line in stderr.decode('utf-8', errors='replace').splitlines():
                local_logger.warning(msg=line)
            return proc.returncode
        return -1

    def update_db_current(
        self,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Update current CIDR file database from sources.

        Wipes <db>/current, then copies every source file whose path contains an
        enabled cidr into <db>/current/<gateway>/ (first match per cidr).

        Args:
            logger_alias (str, optional): sublogger name. Defaults to the enclosing class name.

        Returns:
            bool: True - database updated, False - there are exceptions.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.update_db_current.__annotations__):
            sources_root = self._db_root_path + sep + 'sources'
            sources_real = path.realpath(sources_root)
            current_root = self._db_root_path + sep + 'current'
            cidr_sources = self.__list_files(sources_root)
            try:
                for root, dirs, files in walk(current_root, topdown=False):
                    for file in files:
                        remove(path.join(path.realpath(root), file))
                    for directory in dirs:
                        rmdir(path.join(path.realpath(root), directory))
            except OSError as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                return False
            for gw, cidr_current in self._gw.items():
                for cidr in cidr_current:
                    cidr_saved = False
                    for src_file in cidr_sources:
                        if cidr not in src_file:
                            continue
                        # Rebuild the destination from the path relative to the
                        # sources root; replacing the text 'sources' would also hit
                        # any 'sources' component inside db_root_path itself.
                        dst_file = path.join(
                            current_root, gw, path.relpath(src_file, sources_real)
                        )
                        try:
                            makedirs(path.dirname(dst_file), exist_ok=True)
                            copyfile(src=src_file, dst=dst_file)
                            local_logger.info(msg=dst_file + ' saved')
                            cidr_saved = True
                            break
                        except IOError as error:
                            local_logger.debug(msg='error: ' + '\n' + str(error))
                    if not cidr_saved:
                        local_logger.warning(msg=cidr + ' not saved')
            return True
        return False

    def update_db_sources(
        self,
        name: str,
        db_root_path: str,
        db_source_code: (str, type(None)) = None,
        download_token: (str, type(None)) = None,
        download_user: (str, type(None)) = None,
        download_pass: (str, type(None)) = None,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download and extract sources to CIDR file database.

        Args:
            name (str): 'google', 'amazon', 'antifilter', 'atlassian', 'herrbischoff',
                'ip2location' or 'githmptoday'; '' is ignored.
            db_root_path (str): database directory.
            db_source_code (str, None, optional): ip2location database code. Defaults to None.
            download_token (str, None, optional): ip2location download token. Defaults to None.
            download_user (str, None, optional): githmptoday user. Defaults to None.
            download_pass (str, None, optional): githmptoday password. Defaults to None.
            force_download (bool, optional): download sources even it exists. Defaults to False.
            logger_alias (str, optional): sublogger name. Defaults to the enclosing class name.

        Returns:
            bool: True - sources updated, False - there are exceptions or unknown name.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.update_db_sources.__annotations__):
            simple_sources = {
                'google': self.__update_source_google,
                'amazon': self.__update_source_amazon,
                'antifilter': self.__update_source_antifilter,
                'atlassian': self.__update_source_atlassian,
                'herrbischoff': self.__update_source_herrbischoff,
            }
            if name in simple_sources:
                return simple_sources[name](
                    db_root_path=db_root_path,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            if name == 'ip2location':
                return self.__update_source_ip2location(
                    db_root_path=db_root_path,
                    db_source_code=db_source_code,
                    download_token=download_token,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            if name == 'githmptoday':
                return self.__update_source_githmptoday(
                    db_root_path=db_root_path,
                    download_user=download_user,
                    download_pass=download_pass,
                    force_download=force_download,
                    logger_alias=logger_alias
                )
            if name:
                local_logger.warning(msg=name + ' is not a known source')
        return False

    def __download_db(
        self,
        url: str,
        dst: str,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Download url to dst (binary if the body is not UTF-8 text).

        Returns:
            bool: True - saved, False - download or write failed.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__download_db.__annotations__):
            try:
                makedirs(path.dirname(dst), exist_ok=True)
                response = self.http(url=url, method='GET', logger_alias=logger_alias)
                if not response['success']:
                    raise ConnectionError('downloading ' + url + ' failed')
                if isinstance(response['result'], bytes):
                    with open(dst, mode='wb+') as file:
                        file.write(response['result'])
                else:
                    with open(dst, mode='w+', encoding='utf-8') as file:
                        file.write(response['result'])
                local_logger.info(msg=dst + ' saved')
                return True
            except OSError as error:
                # ConnectionError is an OSError; file-system failures land here too.
                local_logger.warning(msg='' + str(error))
        return False

    def __ensure_source(self, url: str, dst: str, force_download: bool, logger_alias: str) -> bool:
        """Download url to dst when forced or dst is missing; False on failure."""
        if force_download or not path.exists(dst):
            return self.__download_db(url=url, dst=dst, logger_alias=logger_alias)
        return True

    def __update_source_google(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Split goog.json 'prefixes' into ipv4/ipv6 google.cidr files."""
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_google.__annotations__):
            db_source_name = "google"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "goog.json"
            if not self.__ensure_source(
                url="https://www.gstatic.com/ipranges/goog.json",
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                prefixes = json.loads(db_source_raw.read())['prefixes']
            db_source_cidr_root = db_source_root + sep + "cidr"
            self.__save_cidr(
                db_source_cidr_root, "ipv4", db_source_name,
                [item['ipv4Prefix'] for item in prefixes if 'ipv4Prefix' in item],
                local_logger
            )
            self.__save_cidr(
                db_source_cidr_root, "ipv6", db_source_name,
                [item['ipv6Prefix'] for item in prefixes if 'ipv6Prefix' in item],
                local_logger
            )
            return True
        return False

    def __update_source_amazon(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Split ip-ranges.json prefixes into ipv4/ipv6 amazon.cidr files."""
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_amazon.__annotations__):
            db_source_name = "amazon"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "ip-ranges.json"
            if not self.__ensure_source(
                url="https://ip-ranges.amazonaws.com/ip-ranges.json",
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                db_source_data = json.loads(db_source_raw.read())
            db_source_cidr_root = db_source_root + sep + "cidr"
            self.__save_cidr(
                db_source_cidr_root, "ipv4", db_source_name,
                [item['ip_prefix'] for item in db_source_data['prefixes']],
                local_logger
            )
            self.__save_cidr(
                db_source_cidr_root, "ipv6", db_source_name,
                [item['ipv6_prefix'] for item in db_source_data['ipv6_prefixes']],
                local_logger
            )
            return True
        return False

    def __update_source_atlassian(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Split atlassian 'items' cidrs into ipv4/ipv6 (':' marks IPv6)."""
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_atlassian.__annotations__):
            db_source_name = "atlassian"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "ip-ranges.json"
            if not self.__ensure_source(
                url="https://ip-ranges.atlassian.com",
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                cidrs = [item['cidr'] for item in json.loads(db_source_raw.read())['items']]
            db_source_cidr_root = db_source_root + sep + "cidr"
            self.__save_cidr(
                db_source_cidr_root, "ipv4", db_source_name,
                [cidr for cidr in cidrs if ":" not in cidr], local_logger
            )
            self.__save_cidr(
                db_source_cidr_root, "ipv6", db_source_name,
                [cidr for cidr in cidrs if ":" in cidr], local_logger
            )
            return True
        return False

    def __update_source_antifilter(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Split ipsmart.lst lines into ipv4/ipv6 (':' marks IPv6), skipping blanks."""
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_antifilter.__annotations__):
            db_source_name = "antifilter"
            db_source_root = db_root_path + sep + "sources" + sep + db_source_name
            db_source_file = db_source_root + sep + "ipsmart.lst"
            if not self.__ensure_source(
                url="https://antifilter.network/download/ipsmart.lst",
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            with open(db_source_file, mode='r', encoding='utf-8') as db_source_raw:
                items = [line.strip() for line in db_source_raw.read().splitlines()]
            items = [item for item in items if item]
            db_source_cidr_root = db_source_root + sep + "cidr"
            self.__save_cidr(
                db_source_cidr_root, "ipv4", db_source_name,
                [item for item in items if ":" not in item], local_logger
            )
            self.__save_cidr(
                db_source_cidr_root, "ipv6", db_source_name,
                [item for item in items if ":" in item], local_logger
            )
            return True
        return False

    def __update_source_herrbischoff(
        self,
        db_root_path: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Extract per-country ipv4/ipv6 files from the country-ip-blocks archive.

        A corrupted archive is deleted so the next run downloads it again.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_herrbischoff.__annotations__):
            db_source_root = db_root_path + sep + "sources" + sep + "herrbischoff"
            db_source_file = db_source_root + sep + "country-ip-blocks-master.zip"
            db_source_cidr_root = db_source_root + sep + "cidr"
            if not self.__ensure_source(
                url=(""
                    + "https://github.com/herrbischoff/"
                    + "country-ip-blocks/archive/refs/heads/master.zip"
                ),
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            try:
                with ZipFile(db_source_file, mode='r') as db_source_file_zip:
                    for family in ("ipv4", "ipv6"):
                        makedirs(db_source_cidr_root + sep + family, exist_ok=True)
                    for member in db_source_file_zip.infolist():
                        if member.is_dir():
                            continue
                        for family in ("ipv4", "ipv6"):
                            # Zip member names always use '/', whatever os.sep is.
                            if "country-ip-blocks-master/" + family not in member.filename:
                                continue
                            country_file = (""
                                + db_source_cidr_root + sep
                                + family + sep + path.basename(member.filename)
                            )
                            with open(country_file, mode='wb') as country_dump:
                                country_dump.write(db_source_file_zip.read(member.filename))
                            local_logger.info(msg=country_file + ' saved')
                return True
            except BadZipFile as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                local_logger.warning(msg=db_source_file + ' corrupted and deleted')
                remove(db_source_file)
        return False

    def __update_source_ip2location(
        self,
        db_root_path: str,
        db_source_code: str,
        download_token: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Build per-country IPv4 .cidr files from an ip2location database.

        Only 'DB1LITECSV', 'DB1' (IP-number ranges) and 'DB1CIDR' are parsed;
        other codes are reported and skipped before downloading.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_ip2location.__annotations__):
            DB_IP2L_BASE = {
                "DB1LITECSV": "IP2LOCATION-LITE-DB1.CSV",
                "DB1LITECSVIPV6": "IP2LOCATION-LITE-DB1.IPV6.CSV",
                "DB1": "IPCountry.csv",
                "DB1IPV6": "IPV6-COUNTRY.CSV",
                "DB1CIDR": "IP2LOCATION-IP-COUNTRY.CIDR.CSV",
                "DB1CIDRIPV6": "IP2LOCATION-IPV6-COUNTRY.CIDR.CSV"
            }
            if db_source_code not in ('DB1LITECSV', 'DB1', 'DB1CIDR'):
                local_logger.warning(msg=db_source_code + ' is not supported')
                return False
            db_source_root = db_root_path + sep + "sources" + sep + "ip2location"
            db_source_file = db_source_root + sep + db_source_code + ".ZIP"
            if not self.__ensure_source(
                url=(''
                    + "https://www.ip2location.com/download?token=" + download_token
                    + "&file=" + db_source_code
                ),
                dst=db_source_file,
                force_download=force_download,
                logger_alias=logger_alias
            ):
                return False
            try:
                with ZipFile(db_source_file, mode='r') as db_source_file_zip:
                    with db_source_file_zip.open(
                        DB_IP2L_BASE[db_source_code], mode='r'
                    ) as db_source_raw:
                        db_source_data = db_source_raw.read().decode('utf-8')
            except BadZipFile as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                local_logger.warning(msg=db_source_file + ' corrupted and deleted')
                remove(db_source_file)
                return False
            except KeyError as error:
                # The archive does not contain the expected CSV member.
                local_logger.warning(msg=db_source_file + ': ' + str(error))
                return False
            db_parsed_data = {}
            for line in db_source_data.splitlines():
                if not line.strip():
                    continue
                fields = [field.replace('"', '') for field in line.split(',')]
                if db_source_code == 'DB1CIDR':
                    # "cidr","country_code","country_name"
                    country_code = fields[1].lower()
                    country_cidrs = [fields[0]]
                else:
                    # "first_ipnum","last_ipnum","country_code","country_name"
                    country_code = fields[2].lower()
                    country_cidrs = [
                        cidr.exploded
                        for cidr in ipaddress.summarize_address_range(
                            ipaddress.IPv4Address(int(fields[0])),
                            ipaddress.IPv4Address(int(fields[1]))
                        )
                    ]
                db_parsed_data.setdefault(country_code, []).extend(country_cidrs)
            db_source_cidr_root = db_source_root + sep + "cidr"
            for country_code, country_data in db_parsed_data.items():
                self.__save_cidr(
                    db_source_cidr_root, "ipv4", country_code, country_data, local_logger
                )
            return True
        return False

    def __update_source_githmptoday(
        self,
        db_root_path: str,
        download_user: str,
        download_pass: str,
        force_download: bool = False,
        logger_alias: str = inspect.stack()[0].function
    ) -> bool:
        """Clone the prepared my_route.db repository and move it into db_root_path.

        Credentials are inserted verbatim into the clone URL, so URL-reserved
        characters in them must already be percent-encoded.
        """
        local_logger = logging.getLogger(logger_alias)
        if Do.args_valid(locals(), self.__update_source_githmptoday.__annotations__):
            if not path.exists(db_root_path + sep + 'sources'):
                force_download = True
            if not force_download:
                return False
            # Absolute first: dirname() of a bare relative name is '' and would
            # otherwise produce '/tmp'.
            db_temp_path = path.join(path.dirname(path.abspath(db_root_path)), 'tmp')
            repo_url = (''
                + "https://" + download_user + ":" + download_pass
                + "@git.hmp.today/pavel.muhortov/my_route.db.git"
            )
            try:
                # Argument list, no shell: credentials cannot inject shell syntax.
                with Popen(['git', 'clone', repo_url, db_temp_path]) as proc:
                    ret_gitclone = proc.wait()
            except OSError as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
                ret_gitclone = -1
            if ret_gitclone != 0:
                local_logger.warning(msg=''
                    + 'git clone returned '+ str(ret_gitclone) + ' code. '
                    + 'Restart by interactive and check stdout.'
                )
                return False
            try:
                makedirs(db_root_path, exist_ok=True)
                temp_real = path.realpath(db_temp_path)
                for root, dirs, files in walk(db_temp_path, topdown=False):
                    for file in files:
                        src_file = path.join(path.realpath(root), file)
                        dst_file = path.join(db_root_path, path.relpath(src_file, temp_real))
                        makedirs(path.dirname(dst_file), exist_ok=True)
                        copyfile(src=src_file, dst=dst_file)
                        remove(src_file)
                    for directory in dirs:
                        rmdir(path.join(path.realpath(root), directory))
                rmdir(db_temp_path)
                return True
            except OSError as error:
                local_logger.debug(msg='error: ' + '\n' + str(error))
        return False
class Do():
    """Set of various methods (functions) for routine.
    """
    @staticmethod
    def args_valid(arguments: dict, annotations: dict) -> bool:
        """Arguments type validating by annotations.

        Args:
            arguments (dict): 'locals()' immediately after starting the function.
            annotations (dict): function.name.__annotations__; a value may be a
                type or a tuple of types (anything isinstance() accepts).

        Raises:
            TypeError: type of argument is not equal type in annotation.

        Returns:
            bool: True if argument types are valid.
        """
        for var_name, var_type in annotations.items():
            if var_name == 'return':
                continue
            if not isinstance(arguments[var_name], var_type):
                raise TypeError(""
                    + "type of '"
                    + var_name
                    + "' = "
                    + str(arguments[var_name])
                    + " is not "
                    + str(var_type)
                )
        return True

    @staticmethod
    def checkroot() -> bool:
        # pylint: disable=C0415
        """Crossplatform privileged rights checker.

        Returns:
            bool: True - if privileged rights, False - if not privileged rights
                or the platform offers no way to tell.
        """
        if platform.startswith('win32'):
            import ctypes
            return bool(ctypes.windll.shell32.IsUserAnAdmin())
        # Any POSIX platform (linux, darwin, *bsd, ...) exposes geteuid().
        try:
            from os import geteuid
        except ImportError:
            return False
        return geteuid() == 0
if __name__ == "__main__":
    time_start = datetime.now()
    args = ArgumentParser(
        prog='my-route',
        description='Route management by CIDR lists.',
        epilog='Dependencies: '
        '- Python 3 (tested version 3.9.5), '
        '- privileged rights, '
        '- git '
    )
    args.add_argument('--config', type=str,
                      default=path.splitext(__file__)[0] + '.conf',
                      required=False,
                      help='custom configuration file path'
                      )
    args.add_argument('-a', '--add', action='store_true', required=False,
                      help='add routes specified by config')
    args.add_argument('-d', '--del', action='store_true', required=False,
                      help='del routes specified by config')
    args.add_argument('-i', '--imitate', action='store_true', required=False,
                      help='only showing commands without applying them')
    args.add_argument('-u', '--update', action='store_true', required=False,
                      help='update cidr file db')
    args.add_argument('-f', '--force', action='store_true', required=False,
                      help='force download sources for update')
    args = vars(args.parse_args())

    db_root_path = (''
        + path.dirname(path.realpath(__file__)) + sep
        + path.splitext(path.basename(__file__))[0] + '.db')
    log_level = logging.INFO
    log_root = path.dirname(path.realpath(__file__))
    enable_gateway = {}
    enable_sources = {}
    if path.exists(args['config']):
        conf_common = Parse(parameters=args['config'], block='common')
        db_root_path = conf_common.data.get('db_root_path', db_root_path)
        log_root = conf_common.data.get('log_root', log_root)
        # Unknown or missing level names keep the default INFO level.
        log_level = {
            'DEBUG': logging.DEBUG,
            'INFO': logging.INFO,
            'WARNING': logging.WARNING,
            'ERROR': logging.ERROR,
            'CRITICAL': logging.CRITICAL,
        }.get(conf_common.data.get('log_level'), log_level)

        # [enable-gateway] lists gateway blocks; each block lists cidr switches.
        conf_gateway = Parse(parameters=args['config'], block='enable-gateway')
        for key, value in conf_gateway.data.items():
            if Parse.str2bool(value):
                gateway_config = Parse(parameters=args['config'], block=key)
                enable_gateway[key] = [
                    cidr for cidr, enable in gateway_config.data.items()
                    if Parse.str2bool(enable)
                ]

        conf_sources = Parse(parameters=args['config'], block='enable-sources')
        for key, value in conf_sources.data.items():
            if Parse.str2bool(value):
                enable_sources[key] = {
                    'enable': value,
                    'download_token': None,
                    'db_source_code': None,
                    'download_user': None,
                    'download_pass': None
                }
        if 'ip2location' in enable_sources:
            enable_sources['ip2location']['download_token'] = (
                conf_sources.data['ip2l_download_token']
            )
            enable_sources['ip2location']['db_source_code'] = (
                conf_sources.data['ip2l_database_code']
            )
        if 'githmptoday' in enable_sources:
            enable_sources['githmptoday']['download_user'] = (
                conf_sources.data['githmptoday_user']
            )
            enable_sources['githmptoday']['download_pass'] = (
                conf_sources.data['githmptoday_pass']
            )

    # The FileHandler cannot create a missing directory itself.
    makedirs(log_root, exist_ok=True)
    logging.basicConfig(
        format='%(asctime)s %(levelname)s: %(name)s: %(message)s',
        datefmt='%Y-%m-%d_%H.%M.%S',
        handlers=[
            logging.FileHandler(
                filename=log_root + sep + path.splitext(path.basename(__file__))[0] + '.log',
                mode='a'
            ),
            logging.StreamHandler()
        ],
        level=log_level
    )

    if Do.checkroot():
        ro = Route(gateways=enable_gateway, db_root_path=db_root_path)
        if args['update']:
            for key, source in enable_sources.items():
                ro.update_db_sources(
                    name=key,
                    db_root_path=db_root_path,
                    db_source_code=source['db_source_code'],
                    download_token=source['download_token'],
                    download_user=source['download_user'],
                    download_pass=source['download_pass'],
                    force_download=args['force'],
                    logger_alias='update sources ' + key
                )
            ro.update_db_current(logger_alias='update current')
        elif args['add']:
            ro.do(action='add', imitate=args['imitate'])
        elif args['del']:
            ro.do(action='delete', imitate=args['imitate'])
        else:
            logging.info(msg='No start arguments selected. Exit.')
    else:
        logging.warning(msg='Restart this as root!')
    time_execute = datetime.now() - time_start
    logging.info(msg='execution time is ' + str(time_execute) + '. Exit.')