You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
197 lines
7.0 KiB
197 lines
7.0 KiB
# Copyright (C) 2005-2024 Splunk Inc. All Rights Reserved.
|
|
|
|
from functools import wraps
|
|
import http.client
|
|
import json
|
|
import traceback
|
|
|
|
from routes import Mapper
|
|
from splunk.persistconn.application import PersistentServerConnectionApplication
|
|
from ITOA.setup_logging import getLogger
|
|
|
|
from .session import session
|
|
from .exception import BaseRestException
|
|
|
|
try:
|
|
basestring
|
|
except NameError:
|
|
basestring = str
|
|
|
|
logger = getLogger()
|
|
|
|
HTTP_VERBS = ['get', 'post', 'put', 'patch', 'delete']
|
|
HTTP_STATUS_CODES = set(http.client.responses.keys())
|
|
|
|
|
|
def route(path, methods=None):
|
|
"""
|
|
Decorator method for registering route and route handler.
|
|
MUST be used inside of class whose metaclass is RequestHandlerType
|
|
"""
|
|
if not isinstance(path, basestring):
|
|
raise TypeError('path should be a string')
|
|
|
|
method_set = set()
|
|
if methods is None:
|
|
method_set = {m for m in HTTP_VERBS}
|
|
elif not isinstance(methods, list):
|
|
raise TypeError('methods should be a list')
|
|
else:
|
|
for m in methods:
|
|
m = m.lower()
|
|
if m not in HTTP_VERBS:
|
|
raise TypeError('methods should be one of %s, but instead got %s' % (HTTP_VERBS, m))
|
|
method_set.add(m)
|
|
|
|
def route_register(func):
|
|
func.path = path
|
|
func.allowed_methods = method_set
|
|
func.is_handler = True
|
|
|
|
@wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
return func(*args, **kwargs)
|
|
return wrapper
|
|
return route_register
|
|
|
|
|
|
class RequestHandlerType(type):
|
|
'''
|
|
Metaclass for class that serves as a REST handler. It registers all the
|
|
methods that are decorated with the @route decorator as route handlers.
|
|
'''
|
|
def __new__(cls, name, bases, dct):
|
|
result = super(RequestHandlerType, cls).__new__(cls, name, bases, dct)
|
|
registered_routes = Mapper()
|
|
|
|
for name, val in dct.items():
|
|
if getattr(val, 'is_handler', 0):
|
|
registered_routes.connect(None, val.path, handler=val)
|
|
result._routes = registered_routes
|
|
return result
|
|
|
|
|
|
class Request(object):
|
|
'''
|
|
Basic REST request class
|
|
'''
|
|
def __init__(self, path, full_path, method, headers, data, query, session):
|
|
self.path = path
|
|
self.full_path = full_path
|
|
self.method = method
|
|
self.headers = headers
|
|
self.data = data
|
|
self.query = query
|
|
self.session = session
|
|
|
|
@staticmethod
|
|
def build_from_splunkd_request(splunkd_req):
|
|
path = splunkd_req.get('path_info', '')
|
|
full_path = splunkd_req['rest_path']
|
|
method = splunkd_req['method']
|
|
headers = dict(splunkd_req['headers'])
|
|
query = dict(splunkd_req.get('query', []))
|
|
session = splunkd_req['session']
|
|
# We store global session-related info in the user session for ease of use
|
|
session['global.system_authtoken'] = splunkd_req.get('system_authtoken')
|
|
if 'Content-Type' in headers and headers['Content-Type'].lower().startswith('application/json'):
|
|
data = json.loads(splunkd_req.get('payload', '{}'))
|
|
else:
|
|
data = dict(splunkd_req.get('form', []))
|
|
return Request(path, full_path, method, headers, data, query, session)
|
|
|
|
def to_json(self):
|
|
return json.dumps(vars(self))
|
|
|
|
|
|
class BaseRestInterfaceSplunkd(PersistentServerConnectionApplication, metaclass=RequestHandlerType):
|
|
"""
|
|
Base REST interface class.
|
|
|
|
It provides session management, request/response parsing and routing by default.
|
|
Subclasses of this class will only need to worry about REST handler method that
|
|
interacts with business logic.
|
|
|
|
NOTE: this class MUST NOT be imported directly into a module that defines the
|
|
REST handler class that inherits from it due to the constraint imposed by splunkd.
|
|
To inherit it, import the rest_interface_splunkd module and refer to this class as
|
|
rest_interface_splunkd.BaseRestInterfaceSplunkd
|
|
"""
|
|
|
|
def __init__(self, command_line, command_arg):
|
|
super(BaseRestInterfaceSplunkd, self).__init__()
|
|
|
|
def handle(self, in_string):
|
|
'''
|
|
Implementation of handle method defined in PersistentServerConnectionApplication.
|
|
'''
|
|
request_obj = self.extract_args(in_string)
|
|
try:
|
|
session.save(**request_obj.session)
|
|
handler, args = self._get_request_handler_and_args(
|
|
request_obj.path,
|
|
request_obj.method
|
|
)
|
|
|
|
self.execute_before_handle_hooks(request_obj)
|
|
status_code, response = handler(self, request_obj, **args)
|
|
return self._response(status_code, response)
|
|
except BaseRestException as e:
|
|
return self._response(e.code, e.msg)
|
|
except Exception as e:
|
|
logger.error('Invalid response - Error: %s' % e.args[0])
|
|
logger.debug(traceback.format_exc())
|
|
return self._response(http.client.INTERNAL_SERVER_ERROR, _('Internal Server Error'))
|
|
finally:
|
|
session.clear()
|
|
|
|
def execute_before_handle_hooks(self, request):
|
|
hooks = getattr(self, '_registered_before_handle_hooks', [])
|
|
for hook in hooks:
|
|
hook.execute(request)
|
|
|
|
def _get_request_handler_and_args(self, path, method):
|
|
matched_handler = self._routes.match('/' + path if path != '' else path)
|
|
route_handler = matched_handler.pop('handler') if matched_handler else None
|
|
|
|
if route_handler:
|
|
if method.lower() not in route_handler.allowed_methods:
|
|
raise BaseRestException(
|
|
http.client.BAD_REQUEST,
|
|
_('REST method %(method)s not allowed') % {'method': method}
|
|
)
|
|
return route_handler, matched_handler
|
|
else:
|
|
raise BaseRestException(http.client.NOT_FOUND, _('Not Found'))
|
|
|
|
def extract_args(self, in_string):
|
|
try:
|
|
args = json.loads(in_string)
|
|
return Request.build_from_splunkd_request(args)
|
|
except Exception as e:
|
|
logger.error('Failed to build request object - error: %s' % e)
|
|
raise BaseRestException(http.client.BAD_REQUEST, _('Bad Request'))
|
|
|
|
def _response(self, status, payload):
|
|
if status not in HTTP_STATUS_CODES:
|
|
logger.error('Invalid status code %s - payload: %s' % (status, payload))
|
|
return self._response_error(http.client.INTERNAL_SERVER_ERROR, _('Internal Server Error'))
|
|
if status >= 400:
|
|
if not isinstance(payload, basestring):
|
|
logger.error('Invalid error message type %s as payload, must be basestring' % type(payload))
|
|
return self._response_error(http.client.INTERNAL_SERVER_ERROR, _('Internal Server Error'))
|
|
return self._response_error(status, payload)
|
|
else:
|
|
return {
|
|
'status': status,
|
|
'payload': payload
|
|
}
|
|
|
|
def _response_error(self, status, message):
|
|
return {
|
|
'status': status,
|
|
'payload': {
|
|
'message': message
|
|
}
|
|
}
|