You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

194 lines
5.0 KiB

# Copyright 2016 Splunk, Inc.
# SPDX-FileCopyrightText: 2020 2020
#
# SPDX-License-Identifier: Apache-2.0
"""
Common utilities.
"""
import datetime
import logging
import os
import signal
import time
import traceback
try:
from urllib import parse as urlparse
except ImportError:
from urllib2 import urlparse
from functools import wraps
# Public API of this module: the names exported by a star-import.
# ``unescape_json_control_chars`` is listed next to its counterpart
# ``escape_json_control_chars`` so both halves of the pair are exported.
__all__ = [
    "handle_teardown_signals",
    "datetime_to_seconds",
    "is_true",
    "is_false",
    "escape_json_control_chars",
    "unescape_json_control_chars",
    "retry",
    "extract_http_scheme_host_port",
]
def handle_teardown_signals(callback):
    """Install `callback` as the handler for tear down signals.

    SIGTERM and SIGINT are always hooked; SIGBREAK is hooked as well when
    running on Windows (``os.name == "nt"``).

    Note: call this from the main thread, since Python only delivers
    signals to (and only allows handlers to be set from) the main thread.

    :param callback: Callback invoked when a tear down signal arrives.
    :type callback: ``function``
    """

    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, callback)

    # SIGBREAK only exists on Windows, so look it up only there.
    running_on_windows = os.name == "nt"
    if running_on_windows:
        signal.signal(signal.SIGBREAK, callback)
def datetime_to_seconds(dt):
    """Convert a UTC datetime to seconds since the Unix epoch.

    A naive `dt` (no ``tzinfo``) is interpreted as UTC. A timezone-aware
    `dt` is first shifted to UTC using its own offset, so it can be
    converted too instead of failing on naive/aware subtraction.

    :param dt: Date time.
    :type dt: ``datetime.datetime``
    :returns: Seconds since epoch.
    :rtype: ``float``
    """

    offset = dt.utcoffset()
    if offset is not None:
        # Aware minus naive raises TypeError, so express the value as a
        # naive UTC datetime before subtracting the (naive) epoch.
        dt = dt.replace(tzinfo=None) - offset

    # Constant epoch instead of the deprecated ``utcfromtimestamp(0)``.
    epoch_time = datetime.datetime(1970, 1, 1)
    return (dt - epoch_time).total_seconds()
def is_true(val):
    """Tell whether `val` spells out a true value.

    The value is stringified, stripped of surrounding whitespace and
    compared case-insensitively against ``1``, ``TRUE``, ``T``, ``Y``
    and ``YES``.

    :param val: Value to check.
    :type val: ``(integer, string)``
    :returns: True if `val` is one of the accepted true spellings,
        False otherwise.
    :rtype: ``bool``
    """

    truthy_spellings = ("1", "TRUE", "T", "Y", "YES")
    normalized = str(val).strip().upper()
    return normalized in truthy_spellings
def is_false(val):
    """Tell whether `val` spells out a false value.

    The value is stringified, stripped of surrounding whitespace and
    compared case-insensitively against ``0``, ``FALSE``, ``F``, ``N``,
    ``NO``, ``NONE`` and the empty string (so ``None`` counts as false).

    :param val: Value to check.
    :type val: ``(integer, string)``
    :returns: True if `val` is one of the accepted false spellings,
        False otherwise.
    :rtype: ``bool``
    """

    falsy_spellings = ("0", "FALSE", "F", "N", "NO", "NONE", "")
    normalized = str(val).strip().upper()
    return normalized in falsy_spellings
def escape_json_control_chars(json_str):
    """Escape json control chars in `json_str`.

    Every literal backslash sequence ``\\n`` and ``\\r`` (a backslash
    followed by the letter, not a real newline or carriage return)
    gets one extra leading backslash. The replacements run in a fixed order:
    ``\\n`` first, then ``\\r``, then ``\\r\\n``.

    :param json_str: Json string to escape.
    :type json_str: ``string``
    :returns: Escaped string.
    :rtype: ``string``
    """

    # Chained in the same order as the (search, replacement) pairs are
    # meant to be applied; each step works on the previous step's output.
    return (
        json_str.replace(r"\n", "\\\\n")
        .replace(r"\r", "\\\\r")
        .replace(r"\r\n", "\\\\r\\\\n")
    )
def unescape_json_control_chars(json_str):
    """Unescape json control chars in `json_str`.

    Every doubled-backslash sequence ``\\\\n`` and ``\\\\r`` loses one
    leading backslash. The replacements run in a fixed order: ``\\\\n``
    first, then ``\\\\r``, then ``\\\\r\\\\n``.

    :param json_str: Json string to unescape.
    :type json_str: ``string``
    :returns: Unescaped string.
    :rtype: ``string``
    """

    # Chained in the same order as the (search, replacement) pairs are
    # meant to be applied; each step works on the previous step's output.
    return (
        json_str.replace("\\\\n", r"\n")
        .replace("\\\\r", r"\r")
        .replace("\\\\r\\\\n", r"\r\n")
    )
def retry(retries=3, reraise=True, default_return=None, exceptions=None):
    """Decorator factory that re-runs a function when it raises.

    The decorated function is attempted up to ``retries + 1`` times
    (negative `retries` counts as 0). Every failure is logged as a
    warning together with its traceback. Between attempts the wrapper
    sleeps ``2 ** attempt_index`` seconds (1, 2, 4, ...); there is no
    sleep after the final attempt.

    :param retries: (optional) Max retries times, default is 3.
    :type retries: ``integer``
    :param reraise: Whether the last exception should be reraised once
        all attempts failed, default is True.
    :type reraise: ``bool``
    :param default_return: (optional) Value returned once all attempts
        failed and `reraise` is False.
    :param exceptions: (optional) List of exception types that should be
        retried. When empty or None every ``Exception`` is retried; an
        exception matching none of the listed types is raised immediately.
    :type exceptions: ``list``
    """

    max_tries = 1 + max(retries, 0)
    final_attempt = max_tries - 1

    def _is_retryable(error):
        # No filter configured means every Exception qualifies.
        if not exceptions:
            return True
        return any(isinstance(error, kind) for kind in exceptions)

    def do_retry(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            failure = None
            for attempt in range(max_tries):
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    logging.warning(
                        "Run function: %s failed: %s.",
                        func.__name__,
                        traceback.format_exc(),
                    )
                    if not _is_retryable(err):
                        # Not an exception we were asked to retry on.
                        raise

                    failure = err
                    if attempt < final_attempt:
                        # Exponential back-off before the next attempt.
                        time.sleep(2 ** attempt)

            # Every attempt failed with a retryable exception.
            if not reraise:
                return default_return
            raise failure

        return wrapper

    return do_retry
def extract_http_scheme_host_port(http_url):
    """Extract scheme, host and port from a HTTP URL.

    All three parts must be present; a URL without an explicit port
    (or with port 0) is rejected.

    :param http_url: HTTP URL to extract.
    :type http_url: ``string``
    :returns: A tuple of scheme, host and port
    :rtype: ``tuple``
    :raises ValueError: If `http_url` is not in http(s)://hostname:port
        format, including non-string input and malformed ports.
    """

    try:
        http_info = urlparse.urlparse(http_url)
        # ``port`` is computed lazily and raises ValueError for a
        # non-numeric or out-of-range port, so read it inside the ``try``
        # to report that case with the same message as every other one.
        scheme = http_info.scheme
        hostname = http_info.hostname
        port = http_info.port
    except Exception:
        raise ValueError(str(http_url) + " is not in http(s)://hostname:port format")

    if not scheme or not hostname or not port:
        # ``str()`` keeps this a ValueError even when `http_url` is not a
        # string (e.g. None), where plain concatenation would raise TypeError.
        raise ValueError(str(http_url) + " is not in http(s)://hostname:port format")

    return (scheme, hostname, port)