You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

406 lines
15 KiB

import json
import http.client
import uuid
import logging
import copy
from util import rest_url_util
from util.base_util import is_valid_identifier
from util.constants import EXPERIMENT_MODEL_PREFIX
from util.searchinfo_util import searchinfo_from_request
from util.models_util import get_model_list_by_experiment
from util.lookup_exceptions import ModelNotFoundException
from util.rest_proxy import rest_proxy_from_searchinfo
from util.experiment_util import get_experiment_draft_model_name
from models.copymodels import copy_model
from experiment.experiment_validation import validate_experiment_form_args
from rest.proxy import SplunkRestEndpointProxy, SplunkRestProxyException, SplunkRestException
import cexc
# Module-level logger obtained through the project's `cexc` helper.
logger = cexc.get_logger(__name__)
# Key added to a reply's message entry to carry the name of the model that message refers to.
MODEL_NAME_ATTR = 'mltk_model_name'
class ExperimentStore(SplunkRestEndpointProxy):
    """
    API for experiment's conf storage backend.

    Requests are forwarded to the `configs/conf-experiments` endpoint (see
    `URL_PARTS_PREFIX`), and references to that endpoint in the replies are
    rewritten to `/mltk/experiments` (see `_handle_reply`).
    """

    # url parts prepended to every proxied request
    URL_PARTS_PREFIX = ['configs', 'conf-experiments']
    # (key, value) pair that asks the conf endpoint for a JSON reply
    JSON_OUTPUT_FLAG = ('output_mode', 'json')

    # backing values for the two read-only properties below
    _with_admin_token = False
    _with_raw_result = True

    @property
    def with_admin_token(self):
        """Return the class' `_with_admin_token` setting (False here)."""
        return self._with_admin_token

    @property
    def with_raw_result(self):
        """Return the class' `_with_raw_result` setting (True here)."""
        return self._with_raw_result

    def _convert_url_parts(self, url_parts):
        """
        - Mandatory overridden
        - see SplunkRestEndpointProxy._convert_url_parts()

        Args:
            url_parts (list): a list of url parts without /mltk/experiments

        Returns:
            (list) `URL_PARTS_PREFIX` followed by `url_parts`
        """
        return self.URL_PARTS_PREFIX + url_parts

    def _transform_request_options(self, rest_options, url_parts, request):
        """
        - Overridden from SplunkRestEndpointProxy
        - Handling experiment specific modification/handling of the request before
          sending the request to conf endpoint
        - See RestProxy.make_rest_call() for a list of available items for `rest_options`

        Args:
            rest_options (dict): default rest_options constructed by the request method (get, post, delete)
            url_parts (list): a list of url parts without /mltk/experiments
            request (dict): the original request from the rest call to /mltk/experiments/*

        Raises:
            SplunkRestProxyException: the experiment form arguments failed validation

        Returns:
            options (tuple) : a two element tuple. the first element is a dictionary that stores parameters needed by
                RestProxy.make_rest_call(), and the second element stores parameters for _handle_reply, if any.
        """
        reply_options = {}
        method = rest_options['method']

        # for GET/DELETE request, we just want to add output_mode=json to the existing getargs
        if method in ('GET', 'DELETE'):
            # dict() accepts both a list of pairs and a mapping; `or []` covers a None value
            getargs = dict(rest_options.get('getargs') or [])
            flag_key, flag_value = self.JSON_OUTPUT_FLAG
            getargs[flag_key] = flag_value
            rest_options['getargs'] = getargs

        if method == 'DELETE':
            # remove the experiment's models before the experiment itself is deleted
            self._delete_models(request, url_parts)

        # for POST request, we do validation before proxying the request
        if method == 'POST':
            # NOTE: `_split_tuple_list` is not defined in this class (presumably inherited
            # from SplunkRestEndpointProxy); "promoteModel" is kept out of the stored args.
            postargs, blockedargs = self._split_tuple_list(
                request.get('form', []), ["promoteModel"]
            )

            # 'exp-operation' is only a control flag: remove it so it is never stored,
            # and skip schema validation for the rename operation
            is_rename = postargs.pop('exp-operation', '') == 'rename'
            if not is_rename:
                try:
                    validate_experiment_form_args(postargs)
                except Exception as e:
                    logger.error(str(e))
                    raise SplunkRestProxyException(
                        'Cannot validate experiment', logging.ERROR, http.client.BAD_REQUEST
                    ) from e

            postargs['output_mode'] = 'json'

            if blockedargs.get('promoteModel', None):
                reply_options['promote_model'] = True

            if not url_parts:
                # this is a create POST; the hex form has no '-' due to model name constraints
                postargs['name'] = uuid.uuid4().hex

            rest_options['postargs'] = postargs

        return rest_options, reply_options

    def clone_experiment_models(self, experiment_fetch_reply, request, url_parts):
        """
        the function performs the "clone models" operation for an experiment, the experiment info is from
        'experiment_fetch_reply'

        Args:
            experiment_fetch_reply (dict) : the reply from a mltk/experiments/<guid> request
            request (dict) : the request object; its 'payload' must be a JSON object with exactly
                the keys "app" and "name"
            url_parts (list) : a subset of the url, here is a list of length 1 which contains experiment id.

        Raises:
            SplunkRestProxyException: the payload is missing, is not valid JSON, has other keys than
                "app" and "name", or the target model name is not a valid identifier

        Returns:
            (dict) a dictionary of `status` and `payload`
        """
        try:
            # a missing/empty payload is treated as an empty JSON object
            target_info = json.loads(request.get('payload') or '{}')
        except (TypeError, ValueError) as e:
            raise SplunkRestProxyException(
                'Cannot parse the request payload as JSON',
                logging.ERROR,
                http.client.BAD_REQUEST,
            ) from e

        if not isinstance(target_info, dict) or target_info.keys() != {'app', 'name'}:
            raise SplunkRestProxyException(
                'This handler only supports "app" and "name" as arguments',
                logging.ERROR,
                http.client.BAD_REQUEST,
            )

        target_model_name = target_info['name']
        if not is_valid_identifier(target_model_name):
            raise SplunkRestProxyException(
                'Invalid model name "%s"' % target_model_name,
                logging.ERROR,
                http.client.BAD_REQUEST,
            )

        # the target namespace is the source one with only the app replaced
        source_searchinfo = searchinfo_from_request(request)
        target_searchinfo = copy.deepcopy(source_searchinfo)
        target_searchinfo['app'] = target_info['app']

        clone_experiment_model_callback = self._clone_experiment_model_callback(
            source_searchinfo,
            target_searchinfo,
            target_model_name,
            url_parts[0],
            reply_handler=self._add_model_name_to_reply,
        )
        reply_list = self._handle_all_experiment_models(
            experiment_fetch_reply, clone_experiment_model_callback
        )
        formatted_reply = self._handle_clone_reply(reply_list)
        return self._handle_reply(formatted_reply, {}, request, url_parts, 'POST')

    @staticmethod
    def _handle_all_experiment_models(reply, callback_handler):
        """
        pass the callback_handler to each model for all search stages of an experiment, exit if handler returns failure.

        Args:
            reply (dict) : the reply object of an experiment GET request.
            callback_handler (func) : a callback handler for each model, it should return the reply of a REST request.

        Raises:
            Exception: the reply content cannot be loaded, or an entry's search stages are missing/invalid

        Returns:
            (list): a list of replies from each handlers
        """
        try:
            entries = json.loads(reply['content'])['entry']
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Error loading Experiment. Please check mlspl.log for more details."
            ) from e

        # a cache that stores the reply from the callback of each model
        reply_list = []
        for entry in entries:
            # NOTE: the callback runs inside this try block, so a ValueError/KeyError raised
            # by it is reported as an incomplete experiment as well.
            try:
                search_stages = json.loads(entry['content']['searchStages'])
                for search_stage in search_stages:
                    model_name = search_stage.get('modelName')
                    if model_name is None:
                        continue
                    model_reply = callback_handler(model_name)
                    reply_list.append(model_reply)
                    # if any of the reply is not successful, stop the process and return the current reply list
                    if not model_reply.get('success'):
                        return reply_list
            except (ValueError, KeyError) as e:
                cexc.log_traceback()
                raise Exception(
                    "Experiment with id '{id}' is incomplete. This could be because this Experiment has not been saved."
                    " Please check mlspl.log for more details.".format(id=entry.get('name'))
                ) from e
        return reply_list

    @staticmethod
    def _clone_experiment_model_callback(
        source_searchinfo,
        target_searchinfo,
        target_base_model_name,
        experiment_id,
        reply_handler=None,
    ):
        """
        a closure function that take the necessary info to clone a model inside an experiment to designated namespace.

        Args:
            source_searchinfo (dict) : the searchinfo from the source experiment request.
            target_searchinfo (dict) : the searchinfo for the target space for the clone destination
            target_base_model_name (str) : new name of the destination model(s)
            experiment_id (str) : id of the source experiment
            reply_handler (func): handler for each reply, optional. Called with the raw reply and
                the target model name.

        Returns:
            (func) a callback to copy_model that takes model_name <string> as an argument
        """
        source_prefix = EXPERIMENT_MODEL_PREFIX + experiment_id

        def callback(source_model_name):
            # only replace <guid> with the new model name but keeps the suffixes like "_StandardScaler_0"
            target_model_name = source_model_name.replace(
                source_prefix, target_base_model_name, 1
            )
            try:
                raw_reply = copy_model(
                    source_searchinfo, source_model_name, target_searchinfo, target_model_name
                )
            except SplunkRestException as e:
                # a failed REST call still carries a reply; hand that one on instead of raising
                raw_reply = e.get_raw_reply()

            if reply_handler:
                return reply_handler(raw_reply, target_model_name)
            return raw_reply

        return callback

    @staticmethod
    def _add_model_name_to_reply(raw_reply, model_name):
        """
        a util function for customize the reply from Splunk lookup-table-file REST endpoint.
            1. if the reply has no message, append a message with type='INFO', an empty text and the
               custom attribute `mltk_model_name`.
            2. otherwise only add the custom attribute to the first message.

        Args:
            raw_reply (dict) : a dict of raw reply from Splunk lookup-table-file request
            model_name: the model name which needs to be inserted.

        Raises:
            Exception: the reply content is not JSON with a usable `messages` list

        Returns:
            (dict) modified copy of the reply; `raw_reply` itself is not changed.
        """
        reply = copy.deepcopy(raw_reply)
        try:
            content = json.loads(raw_reply['content'])
            messages = content['messages']
            if messages:
                messages[0][MODEL_NAME_ATTR] = model_name
            else:
                messages.append({'type': "INFO", 'text': '', MODEL_NAME_ATTR: model_name})
            reply['content'] = json.dumps(content)
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Invalid JSON response from REST API, Please check mlspl.log for more details."
            ) from e
        return reply

    @staticmethod
    def _promote_draft_model_callback(searchinfo):
        """
        a closure function that take the necessary info to clone a model within the same namespace.

        Args:
            searchinfo (dict): the searchinfo used as both source and target of the copy

        Returns:
            (func) a callback to copy_model that takes model_name <string> as an argument; it raises
            SplunkRestProxyException (NOT_FOUND) when the draft model does not exist
        """

        def callback(model_name):
            draft_model_name = get_experiment_draft_model_name(model_name)
            try:
                # copy the draft model over the model with the final name
                return copy_model(searchinfo, draft_model_name, searchinfo, model_name)
            except ModelNotFoundException as e:
                cexc.log_traceback()
                logger.error(e)
                raise SplunkRestProxyException(
                    "%s: %s" % (str(e), draft_model_name), logging.ERROR, http.client.NOT_FOUND
                ) from e

        return callback

    @staticmethod
    def _delete_models(request, url_parts):
        """
        delete all models of an experiment, best effort: any failure is logged and swallowed.

        Args:
            request (dict) : the request object
            url_parts (list) : a list of url parts; nothing is done unless it holds exactly
                one element, the experiment id.
        """
        if len(url_parts) != 1:
            return

        try:
            searchinfo = searchinfo_from_request(request)
            rest_proxy = rest_proxy_from_searchinfo(searchinfo)
            model_list = get_model_list_by_experiment(
                rest_proxy, namespace='user', experiment_id=url_parts[0]
            )
            for model_name in model_list:
                url = rest_url_util.make_get_lookup_url(
                    rest_proxy, namespace='user', lookup_file=model_name
                )
                rest_proxy.make_rest_call('DELETE', url)
        except Exception:
            # deliberate best effort: a model clean-up failure is only logged
            cexc.log_traceback()

    @staticmethod
    def _handle_clone_reply(replies):
        """
        merge the 'messages' part of all replies into a copy of the last processed reply.

        Args:
            replies (list) : the replies from all splunk REST requests, with ['content']['messages'] modified by
                _clone_experiment_model_callback()

        Raises:
            Exception: `replies` is empty, or a reply does not hold JSON content with at least one message

        Returns:
            (dict) a modified version of mltk clone reply, trimming all attributes in `content` except 'messages'.
        """
        if not replies:
            # nothing was cloned, so there is no reply to build the answer from
            logger.error('No reply to merge: no model was processed for this clone request')
            raise Exception(
                "No model was found to clone. Please check mlspl.log for more details."
            )

        messages = []
        last_reply = None
        try:
            for reply in replies:
                messages.append(json.loads(reply['content'])['messages'][0])
                last_reply = reply
                # stop at the first failed reply; it becomes the base of the merged reply
                if not reply['success']:
                    break
            # shallow copy so the caller's reply object is left untouched
            merged_reply = dict(last_reply)
            merged_reply['content'] = json.dumps({'messages': messages})
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Invalid JSON response from REST API, Please check mlspl.log for more details."
            ) from e
        return merged_reply

    def _handle_reply(self, reply, reply_options, request, url_parts, method):
        """
        - Overridden from SplunkRestEndpointProxy
        - Replace '/configs/conf-experiments' in the reply with '/mltk/experiments'
        - Promote the draft models when `reply_options` asks for it and the POST succeeded

        Args:
            reply (dict): the reply we got from '/configs/conf-experiments'
            reply_options (dict): the reply options from '_transform_request_options'
            request (dict): the original request
            url_parts (list): a list of url parts without /mltk/experiments
            method (string): original request's method

        Returns:
            (dict) a dictionary of `status`, `payload` (JSON string of the rewritten content) and `headers`
        """
        proxied_prefix = '/%s' % '/'.join(self.URL_PARTS_PREFIX)

        def deproxy(string):
            # replace '/configs/conf-experiments' with '/mltk/experiments'
            return string.replace(proxied_prefix, '/mltk/experiments')

        content = json.loads(reply.get('content'))

        if content.get('origin'):
            content['origin'] = deproxy(content['origin'])

        if content.get('links'):
            content['links'] = {key: deproxy(value) for key, value in content['links'].items()}

        for item in content.get('entry') or []:
            item['id'] = deproxy(item['id'])
            item['links'] = {key: deproxy(value) for key, value in item['links'].items()}

        # promote the draft model to production.
        should_promote = (
            reply_options.get('promote_model')
            and method == 'POST'
            and reply.get('status') == http.client.OK
        )
        if should_promote:
            searchinfo = searchinfo_from_request(request)
            promote_draft_model_callback = self._promote_draft_model_callback(searchinfo)
            promote_replies = self._handle_all_experiment_models(
                reply, promote_draft_model_callback
            )
            # processing stops at the first unsuccessful copy, so only the last reply can be a failure
            if promote_replies and not promote_replies[-1].get('success'):
                logger.error('Promoting the draft models stopped after an unsuccessful model copy')

        return {
            'status': reply.get('status', http.client.OK),
            'payload': json.dumps(content),
            'headers': {'Content-Type': 'application/json'},
        }