You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

527 lines
17 KiB

#
# Copyright 2025 Splunk Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
REST Handler.
"""
import json
import traceback
import urllib.parse
from typing import Optional, Any
from defusedxml import ElementTree
from functools import wraps
from solnlib.splunk_rest_client import SplunkRestClient
from solnlib.utils import is_true
from splunklib import binding
from .credentials import RestCredentials
from .entity import RestEntity
from .error import RestError
# Names exported by ``from <module> import *``.
__all__ = ["RestHandler"]
# Limits applied to entity names by the name validation in ``_pre_request``:
# characters that may not appear anywhere in a name, names that are rejected
# outright, and the length at which a name is rejected (len >= MAX_LENGTH).
BASIC_NAME_VALIDATORS = {
"PROHIBITED_NAME_CHARACTERS": ["*", "\\", "[", "]", "(", ")", "?", ":"],
"PROHIBITED_NAMES": ["default", ".", ".."],
"MAX_LENGTH": 1024,
}
# Base name used to build the REST path that is queried for the reload flag.
_TA_CONFIG_FILENAME = "_TA_config"
_TA_CONFIG_ENDPOINT = f"configs/conf-{_TA_CONFIG_FILENAME}"
# Key looked up in the content of the "config" entry of that endpoint.
_NEED_RELOAD_PARAMETER = "need_reload"
def _check_name_for_create(name):
if name == "default":
raise RestError(400, '"%s" is not allowed for entity name' % name)
if name.startswith("_"):
raise RestError(400, 'Name starting with "_" is not allowed for entity')
def _parse_error_msg(exc: binding.HTTPError) -> str:
permission_msg = "do not have permission to perform this operation"
try:
msgs = json.loads(exc.body)["messages"]
text = msgs[0]["text"]
except json.JSONDecodeError:
try:
text = ElementTree.fromstring(exc.body).findtext("./messages/msg")
except ElementTree.ParseError:
return exc.body.decode()
except (KeyError, IndexError):
return exc.body.decode()
if exc.status == 403 and permission_msg in text:
return "This operation is forbidden."
return text
def _pre_request(existing):
    """
    Build a decorator that validates and encodes the payload before the
    decorated handler method is executed.

    :param existing:
        if True: the entity must exist (404 otherwise)
        if False: the entity must NOT exist (409 otherwise)
    :return: decorator for ``RestHandler`` methods taking ``(name, data)``
    """

    def _lookup_current(handler, name):
        """Enforce the existence rule and return the stored content, if any."""
        if not existing:
            # for create, check name
            _check_name_for_create(name)

        # a failing lookup is handled like "no such entity"
        try:
            found = list(handler.get(name))
        except RestError:
            found = []

        if existing:
            if not found:
                raise RestError(404, '"%s" does not exist' % name)
        elif found:
            raise RestError(409, 'Name "%s" is already in use' % name)

        return found[0].content if found else None

    def _validate_name(name: str):
        """Apply the generic name rules from ``BASIC_NAME_VALIDATORS``."""
        candidate = str(name)
        forbidden_chars = BASIC_NAME_VALIDATORS["PROHIBITED_NAME_CHARACTERS"]
        forbidden_names = BASIC_NAME_VALIDATORS["PROHIBITED_NAMES"]
        limit = BASIC_NAME_VALIDATORS["MAX_LENGTH"]
        reserved_msg = (
            f'{forbidden_names}, string started with "_" and string including any one '
            f'of {forbidden_chars} are reserved value which cannot be used for field Name"'
        )

        is_reserved = (
            candidate.startswith("_")
            or candidate in forbidden_names
            or any(char in forbidden_chars for char in candidate)
        )
        if is_reserved:
            raise RestError(400, reserved_msg)
        if len(candidate) >= limit:
            raise RestError(400, f"Field Name must be less than {limit} characters")

    def _decorate(meth):
        """
        :param meth: RestHandler instance method
        :return: wrapped method running all checks first
        """

        @wraps(meth)
        def wrapper(self, name, data):
            endpoint_validate = self._endpoint.validate
            endpoint_validate(name, data, _lookup_current(self, name))
            _validate_name(name)
            self._endpoint.validate_special(name, data)
            self._endpoint.encode(name, data)
            return meth(self, name, data)

        return wrapper

    return _decorate
def _decode_response(meth):
"""
Decode response body.
:param meth: RestHandler instance method
:return:
"""
def decode(self, name, data, acl):
self._endpoint.decode(name, data)
return RestEntity(
name,
data,
self._endpoint.model(name),
self._endpoint.user,
self._endpoint.app,
acl=acl,
)
@wraps(meth)
def wrapper(self, *args, **kwargs):
try:
for name, data, acl in meth(self, *args, **kwargs):
yield decode(self, name, data, acl)
except RestError:
raise
except binding.HTTPError as exc:
raise RestError(exc.status, _parse_error_msg(exc))
except Exception:
raise RestError(500, traceback.format_exc())
return wrapper
class RestHandler:
    """
    Perform REST operations on the entities of a single endpoint.

    Requests are sent through a ``SplunkRestClient``. Fields that the
    endpoint model marks as encrypted are processed with ``RestCredentials``
    and are masked in responses unless decryption is requested.
    """

    def __init__(self, splunkd_uri, session_key, endpoint, *args, **kwargs):
        """
        :param splunkd_uri: splunkd URI; scheme, host and port are parsed from it
        :param session_key: session key passed to the REST client and credentials
        :param endpoint: endpoint object providing app, model, validation, etc.
        :param args: extra positional arguments, stored unchanged
        :param kwargs: extra keyword arguments, stored unchanged
        """
        self._splunkd_uri = splunkd_uri
        self._session_key = session_key
        self._endpoint = endpoint
        self._args = args
        self._kwargs = kwargs
        # Endpoints without a ``conf_name`` never trigger a reload.
        self._conf_name = getattr(endpoint, "conf_name", None)

        splunkd_info = urllib.parse.urlparse(self._splunkd_uri)
        self._client = SplunkRestClient(
            self._session_key,
            self._endpoint.app,
            scheme=splunkd_info.scheme,
            host=splunkd_info.hostname,
            port=splunkd_info.port,
        )
        self.rest_credentials = RestCredentials(
            self._splunkd_uri,
            self._session_key,
            self._endpoint,
        )
        # Placeholder shown instead of the value of an encrypted field.
        self.PASSWORD = "******"

    @_decode_response
    def get(self, name, decrypt=False):
        """
        Fetch one entity.

        :param name: entity name
        :param decrypt: if False, non-empty encrypted fields are masked
        :return: iterable of ``(name, data, acl)`` (decoded by the decorator)
        """
        self.reload_if_needed()
        response = self._client.get(
            self.path_segment(
                self._endpoint.internal_endpoint,
                name=name,
            ),
            output_mode="json",
        )
        return self._format_response(response, get=True, decrypt=decrypt)

    @_decode_response
    def all(self, decrypt=False, **query):
        """
        Fetch all entities of the endpoint.

        :param decrypt: if False, non-empty encrypted fields are masked
        :param query: extra query parameters forwarded to the request
        :return: iterable of ``(name, data, acl)`` (decoded by the decorator)
        """
        self.reload_if_needed()
        # The response is always parsed as JSON, so a caller-supplied
        # ``output_mode`` is dropped; passing it through would also raise a
        # duplicate-keyword TypeError on the call below.
        query.pop("output_mode", None)
        response = self._client.get(
            self.path_segment(self._endpoint.internal_endpoint),
            output_mode="json",
            **query,
        )
        return self._format_all_response(response, decrypt)

    def get_encrypted_field_names(self, name):
        """
        Return the names of the model fields flagged as encrypted.

        :param name: entity name used to select the model (may be None)
        """
        return [x.name for x in self._endpoint.model(name).fields if x.encrypted]

    @_decode_response
    @_pre_request(existing=False)
    def create(self, name, data):
        """
        Create an entity; it must not exist yet.

        :param name: entity name
        :param data: entity content; modified in place (``name`` is added
            and encrypted fields are processed before posting)
        """
        data["name"] = name
        self.rest_credentials.encrypt_for_create(name, data)
        response = self._client.post(
            self.path_segment(self._endpoint.internal_endpoint),
            output_mode="json",
            body=data,
        )
        return self._format_response(response)

    @_decode_response
    @_pre_request(existing=True)
    def update(self, name, data):
        """
        Update an entity; it must already exist.

        :param name: entity name
        :param data: new content; encrypted fields are processed in place
        """
        self.rest_credentials.encrypt_for_update(name, data)
        response = self._client.post(
            self.path_segment(
                self._endpoint.internal_endpoint,
                name=name,
            ),
            output_mode="json",
            body=data,
        )
        return self._format_response(response)

    @_decode_response
    def delete(self, name):
        """
        Delete an entity and, if its model has encrypted fields, its
        stored credentials.

        :param name: entity name
        """
        response = self._client.delete(
            self.path_segment(
                self._endpoint.internal_endpoint,
                name=name,
            ),
            output_mode="json",
        )

        # delete credentials if there are encrypted fields
        if self.get_encrypted_field_names(name):
            rest_credentials = RestCredentials(
                self._splunkd_uri,
                self._session_key,
                self._endpoint,
            )
            rest_credentials.delete(name)
        return self._flay_response(response)

    @_decode_response
    def disable(self, name):
        """
        Post the ``disable`` action for an entity.

        :param name: entity name
        """
        response = self._client.post(
            self.path_segment(
                self._endpoint.internal_endpoint,
                name=name,
                action="disable",
            ),
            output_mode="json",
        )
        return self._flay_response(response)

    @_decode_response
    def enable(self, name):
        """
        Post the ``enable`` action for an entity.

        :param name: entity name
        """
        response = self._client.post(
            self.path_segment(
                self._endpoint.internal_endpoint,
                name=name,
                action="enable",
            ),
            output_mode="json",
        )
        return self._flay_response(response)

    def reload_if_needed(self):
        """Reload the endpoint when it has a conf name and a reload is needed."""
        if self._conf_name and self.is_reload_needed():
            self.reload()

    def is_reload_needed(self) -> bool:
        """
        Tell whether the endpoint should be reloaded before reading.

        The remotely stored flag wins; when it cannot be determined, the
        endpoint's own ``need_reload`` value is used.
        """
        need_reload = self._is_reload_needed()
        if need_reload is None:
            need_reload = self._endpoint.need_reload
        return need_reload

    def _is_reload_needed(self) -> Optional[bool]:
        """
        Read the reload flag from the "config" entry of the TA config
        endpoint.

        :return: the flag, or None when the request fails with an HTTP
            error or no matching entry carries the flag
        """
        name = "config"
        try:
            response = self._client.get(
                self.path_segment(
                    _TA_CONFIG_ENDPOINT,
                    name=name,
                ),
                output_mode="json",
            )
        except binding.HTTPError:
            return None

        response = json.loads(response.body.read())
        if "entry" in response:
            for entry in response["entry"]:
                if entry["name"] == name and _NEED_RELOAD_PARAMETER in entry["content"]:
                    need_reload = is_true(entry["content"][_NEED_RELOAD_PARAMETER])
                    if need_reload is not None:
                        return need_reload
        return None

    def reload(self):
        """Request the ``_reload`` action on the endpoint."""
        self._client.get(
            self.path_segment(
                self._endpoint.internal_endpoint,
                action="_reload",
            ),
        )

    def get_endpoint(self):
        """Return the endpoint object this handler was created with."""
        return self._endpoint

    @classmethod
    def path_segment(cls, endpoint, name=None, action=None):
        """
        Make path segment for given context in Splunk REST format:
        <endpoint>/<entity>/<action>

        :param endpoint: Splunk REST endpoint, e.g. data/inputs
        :param name: entity name for request, "/" will be quoted
        :param action: Splunk REST action, e.g. disable, enable
        :return: the path without leading or trailing "/"
        """
        template = "{endpoint}{entity}{action}"
        entity = ""
        if name:
            # all special characters except "/" will be
            # url-encoded in splunklib.binding.UrlEncoded
            entity = "/" + name.replace("/", "%2F")
        path = template.format(
            endpoint=endpoint.strip("/"),
            entity=entity,
            action="/%s" % action if action else "",
        )
        return path.strip("/")

    @staticmethod
    def _load_json_body(response):
        """
        Read the response body and parse it as JSON.

        :raises RestError: 500 when the body is not valid JSON
        """
        body = response.body.read()
        try:
            return json.loads(body)
        except ValueError:
            raise RestError(500, "Fail to load response, invalid JSON")

    def _format_response(self, response, get=False, decrypt=False):
        """
        Yield ``(name, data, acl)`` for every entry of a response.

        :param response: REST response whose body is a JSON document
        :param get: True for a read request; encrypted fields are then run
            through ``decrypt_for_get`` and any values it returns are
            posted back to the entity
        :param decrypt: if False, non-empty encrypted fields are masked
        """
        cont = self._load_json_body(response)

        for entry in cont["entry"]:
            name = entry["name"]
            data = entry["content"]
            acl = entry["acl"]
            encrypted_field_names = self.get_encrypted_field_names(name)

            # encrypt and get clear password for get request
            if get:
                masked = self.rest_credentials.decrypt_for_get(name, data)
                if masked:
                    self._client.post(
                        self.path_segment(
                            self._endpoint.internal_endpoint,
                            name=name,
                        ),
                        body=masked,
                    )

            if not decrypt:
                # replace clear password with '******'
                for field_name in encrypted_field_names:
                    if field_name in data and data[field_name]:
                        data[field_name] = self.PASSWORD

            yield name, data, acl

    def _flay_response(self, response, decrypt=False):
        """
        Yield ``(name, data, acl)`` for every entry of a response, loading
        credentials when needed and removing encrypted fields unless
        ``decrypt`` is set.

        :param response: REST response whose body is a JSON document
        :param decrypt: if False, encrypted fields are removed from the data
        """
        cont = self._load_json_body(response)

        for entry in cont["entry"]:
            name = entry["name"]
            data = entry["content"]
            acl = entry["acl"]
            if self._need_decrypt(name, data, decrypt):
                self._load_credentials(name, data)
            if not decrypt:
                self._clean_credentials(name, data)
            yield name, data, acl

    def _format_all_response(self, response, decrypt=False):
        """
        Yield ``(name, data, acl)`` for every entry of a collection response.

        :param response: REST response whose body is a JSON document
        :param decrypt: if False, non-empty encrypted fields are masked
        """
        cont = self._load_json_body(response)

        # cont['entry']: collection list, load credentials in one request
        if self.get_encrypted_field_names(None):
            self._encrypt_raw_credentials(cont["entry"])
            if not decrypt:
                self._clean_all_credentials(cont["entry"])

        for entry in cont["entry"]:
            name = entry["name"]
            data = entry["content"]
            acl = entry["acl"]
            yield name, data, acl

    def _load_credentials(self, name, data):
        """
        Decrypt the encrypted fields of one entity in place and post back
        whatever ``RestCredentials.decrypt`` reports as changed.
        """
        rest_credentials = RestCredentials(
            self._splunkd_uri, self._session_key, self._endpoint
        )
        masked = rest_credentials.decrypt(name, data)
        if masked:
            # passwords.conf changed
            self._client.post(
                self.path_segment(
                    self._endpoint.internal_endpoint,
                    name=name,
                ),
                **masked,
            )

    def _encrypt_raw_credentials(self, data):
        """
        Decrypt all entries in one go and mask, on the server side, the
        encrypted fields of every entry in the returned change list.

        :param data: list of response entries (``name`` / ``content``)
        """
        rest_credentials = RestCredentials(
            self._splunkd_uri, self._session_key, self._endpoint
        )
        # get clear passwords for response data and get the password change list
        change_list = rest_credentials.decrypt_all(data)

        field_names = self.get_encrypted_field_names(None)
        for model in change_list:
            # only updates the defined fields in schema
            masked = dict()
            for field in field_names:
                if (
                    field in model["content"]
                    and model["content"][field] != ""
                    and model["content"][field] != self.PASSWORD
                ):
                    masked[field] = self.PASSWORD

            if masked:
                self._client.post(
                    self.path_segment(
                        self._endpoint.internal_endpoint,
                        name=model["name"],
                    ),
                    body=masked,
                )

    def _need_decrypt(self, name, data, decrypt):
        """
        Decide whether credentials have to be loaded for an entity.

        :return: True when an encrypted field holds a non-empty value other
            than the mask, or when ``decrypt`` is set and the model has at
            least one encrypted field; False otherwise
        """
        # some encrypted-needed fields are plain text in *.conf.
        encrypted_field = False
        for field in self._endpoint.model(name).fields:
            if field.encrypted is False:
                # ignore non-encrypted fields
                continue
            encrypted_field = True
            if not data.get(field.name):
                # ignore un-stored/empty fields
                continue
            if data[field.name] == RestCredentials.PASSWORD:
                # ignore already-encrypted fields
                continue
            return True

        if decrypt and encrypted_field:
            # clear credentials is required by request and
            # there are some encrypted-needed fields
            return True

        return False

    def _clean_credentials(self, name, data):
        """Remove every encrypted field from ``data`` in place."""
        for field_name in self.get_encrypted_field_names(name):
            data.pop(field_name, None)

    def _clean_all_credentials(self, data):
        """Mask every non-empty encrypted field of all entries in place."""
        encrypted_field_names = self.get_encrypted_field_names(None)
        for model in data:
            for field_name in encrypted_field_names:
                if (
                    field_name in model["content"]
                    and model["content"][field_name] != ""
                ):
                    model["content"][field_name] = self.PASSWORD