You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
406 lines
15 KiB
import json
|
|
import http.client
|
|
import uuid
|
|
import logging
|
|
import copy
|
|
|
|
from util import rest_url_util
|
|
from util.base_util import is_valid_identifier
|
|
from util.constants import EXPERIMENT_MODEL_PREFIX
|
|
from util.searchinfo_util import searchinfo_from_request
|
|
from util.models_util import get_model_list_by_experiment
|
|
from util.lookup_exceptions import ModelNotFoundException
|
|
|
|
from util.rest_proxy import rest_proxy_from_searchinfo
|
|
from util.experiment_util import get_experiment_draft_model_name
|
|
from models.copymodels import copy_model
|
|
from experiment.experiment_validation import validate_experiment_form_args
|
|
from rest.proxy import SplunkRestEndpointProxy, SplunkRestProxyException, SplunkRestException
|
|
|
|
import cexc
|
|
|
|
# Module-level logger, obtained through the project's cexc logging helper.
logger = cexc.get_logger(__name__)

# Attribute name added to a REST reply's "messages" entry to record which
# model that reply refers to (see ExperimentStore._add_model_name_to_reply).
MODEL_NAME_ATTR = "mltk_model_name"
|
|
|
|
|
|
class ExperimentStore(SplunkRestEndpointProxy):
    """
    API for experiment's conf storage backend.

    Proxies /mltk/experiments/* requests to the conf endpoint
    /configs/conf-experiments/*. Along the way it:
      - validates POST form arguments before they are stored,
      - deletes an experiment's user-namespace models on DELETE (best effort),
      - promotes draft models after a successful POST with `promoteModel`,
      - clones an experiment's models into another app (clone_experiment_models),
      - rewrites conf URLs in replies back to /mltk/experiments.
    """

    # URL parts of the conf endpoint backing the experiment store.
    URL_PARTS_PREFIX = ['configs', 'conf-experiments']
    # Query argument that makes the Splunk REST API answer in JSON.
    JSON_OUTPUT_FLAG = ('output_mode', 'json')

    _with_admin_token = False
    _with_raw_result = True

    @property
    def with_admin_token(self):
        """Whether proxied requests use an admin token (class default: False)."""
        return self._with_admin_token

    @property
    def with_raw_result(self):
        """Whether proxied requests return the raw reply (class default: True)."""
        return self._with_raw_result

    def _convert_url_parts(self, url_parts):
        """
        - Mandatory overridden
        - see SplunkRestEndpointProxy._convert_url_parts()

        Prefixes the incoming url parts with the conf endpoint parts, e.g.
        ['<id>'] -> ['configs', 'conf-experiments', '<id>'].
        """
        return self.URL_PARTS_PREFIX + url_parts

    def _transform_request_options(self, rest_options, url_parts, request):
        """
        - Overridden from SplunkRestEndpointProxy
        - Handling experiment specific modification/handling of the request before
          sending the request to conf endpoint
        - See RestProxy.make_rest_call() for a list of available items for `rest_options`

        Args:
            rest_options (dict): default rest_options constructed by the request method (get, post, delete)
            url_parts (list): a list of url parts without /mltk/experiments
            request (dict): the original request from the rest call to /mltk/experiments/*

        Raises:
            SplunkRestProxyException: some error occurred during this process
                (BAD_REQUEST when the POST form fails validation)

        Returns:
            options (tuple) : a two element tuple. the first element is a dictionary that stores parameters needed by
                RestProxy.make_rest_call(), and the second element stores parameters for _handle_reply, if any.
        """
        reply_options = {}
        method = rest_options['method']

        # for GET/DELETE requests, add output_mode=json to the existing getargs
        if method in ('GET', 'DELETE'):
            # dict() accepts either a list of (key, value) pairs or a mapping;
            # the JSON flag is set last so it always takes precedence.
            getargs = dict(rest_options.get('getargs') or [])
            flag_key, flag_value = self.JSON_OUTPUT_FLAG
            getargs[flag_key] = flag_value
            rest_options['getargs'] = getargs

            if method == 'DELETE':
                # best-effort cleanup of the models that belong to the experiment
                self._delete_models(request, url_parts)

        # for POST requests, we do validation before proxying the request
        if method == 'POST':
            postargs, blockedargs = self._split_tuple_list(
                request.get('form', []), ["promoteModel"]
            )

            # skip schema validation for rename operation
            is_rename = postargs.get('exp-operation', '') == 'rename'
            # delete 'exp-operation' attribute to prevent storing it in the conf
            postargs.pop('exp-operation', '')

            if not is_rename:
                try:
                    validate_experiment_form_args(postargs)
                except Exception as e:
                    logger.error(str(e))
                    raise SplunkRestProxyException(
                        'Cannot validate experiment', logging.ERROR, http.client.BAD_REQUEST
                    ) from e

            postargs['output_mode'] = 'json'

            if blockedargs.get('promoteModel'):
                # _handle_reply promotes the draft models once the POST succeeds
                reply_options['promote_model'] = True

            if not url_parts:
                # this is a create POST: generate the experiment id.
                # uuid4().hex has no '-' characters (model name constraints).
                postargs['name'] = uuid.uuid4().hex

            rest_options['postargs'] = postargs

        return rest_options, reply_options

    def clone_experiment_models(self, experiment_fetch_reply, request, url_parts):
        """
        Perform the "clone models" operation for an experiment whose info is in
        `experiment_fetch_reply`, copying each model into the app named in the
        request payload under the requested base model name.

        Args:
            experiment_fetch_reply (dict) : the raw reply of fetching the experiment
                (mltk/experiments/<guid>)
            request (dict) : the request object; its 'payload' must be a JSON object
                with exactly the keys "app" and "name"
            url_parts (list) : a subset of the url, here is a list of length 1 which contains experiment id.

        Raises:
            SplunkRestProxyException: BAD_REQUEST if the payload is missing keys, is
                not valid JSON, is not an object, or the model name is invalid.

        Returns:
            (dict) a dictionary of `status`, `payload` and `headers`
        """
        payload = request.get('payload')
        try:
            # a missing/empty payload becomes {} so it is rejected below with a
            # clear BAD_REQUEST instead of a TypeError from json.loads
            target_info = json.loads(payload) if payload else {}
        except (TypeError, ValueError) as e:
            raise SplunkRestProxyException(
                'Invalid JSON payload', logging.ERROR, http.client.BAD_REQUEST
            ) from e

        if not isinstance(target_info, dict) or target_info.keys() != {'app', 'name'}:
            raise SplunkRestProxyException(
                'This handler only supports "app" and "name" as arguments',
                logging.ERROR,
                http.client.BAD_REQUEST,
            )

        target_model_name = target_info['name']
        if not is_valid_identifier(target_model_name):
            raise SplunkRestProxyException(
                'Invalid model name "%s"' % target_model_name,
                logging.ERROR,
                http.client.BAD_REQUEST,
            )

        source_searchinfo = searchinfo_from_request(request)
        # same searchinfo as the source, but pointing at the target app
        target_searchinfo = copy.deepcopy(source_searchinfo)
        target_searchinfo['app'] = target_info['app']

        clone_experiment_model_callback = self._clone_experiment_model_callback(
            source_searchinfo,
            target_searchinfo,
            target_model_name,
            url_parts[0],
            reply_handler=self._add_model_name_to_reply,
        )

        reply_list = self._handle_all_experiment_models(
            experiment_fetch_reply, clone_experiment_model_callback
        )

        formatted_reply = self._handle_clone_reply(reply_list)

        return self._handle_reply(formatted_reply, {}, request, url_parts, 'POST')

    @staticmethod
    def _handle_all_experiment_models(reply, callback_handler):
        """
        Pass the model name of every search stage of every experiment entry to
        `callback_handler`, stopping as soon as a callback reply is not successful.

        Args:
            reply (dict) : the raw reply of an experiment request; reply['content']
                is a JSON body whose 'entry' items hold a JSON 'searchStages' string.
            callback_handler (func) : a callback handler for each model; it should
                return the reply of a REST request (a dict with a 'success' flag).

        Raises:
            Exception: if the reply cannot be parsed, or an entry has missing or
                malformed search stages. Exceptions raised by `callback_handler`
                propagate unchanged.

        Returns:
            (list): the replies from each callback call, in call order
        """
        try:
            entries = json.loads(reply['content'])['entry']
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Error loading Experiment. Please check mlspl.log for more details."
            ) from e

        # a cache that stores the reply from the callback of each model
        reply_list = []

        for entry in entries:
            # Only the parsing is guarded: errors raised by the callback itself
            # must not be misreported as an incomplete experiment.
            try:
                search_stages = json.loads(entry['content']['searchStages'])
            except (TypeError, ValueError, KeyError) as e:
                cexc.log_traceback()
                raise Exception(
                    "Experiment with id '{id}' is incomplete. This could be because this Experiment has not been saved."
                    " Please check mlspl.log for more details.".format(id=entry.get('name'))
                ) from e

            for search_stage in search_stages:
                model_name = search_stage.get('modelName')
                if model_name is None:
                    continue
                stage_reply = callback_handler(model_name)
                reply_list.append(stage_reply)
                # if any reply is not successful, stop and return the replies so far
                if not stage_reply.get('success'):
                    return reply_list

        return reply_list

    @staticmethod
    def _clone_experiment_model_callback(
        source_searchinfo,
        target_searchinfo,
        target_base_model_name,
        experiment_id,
        reply_handler=None,
    ):
        """
        A closure that takes the necessary info to clone a model of an experiment
        into the designated namespace.

        Args:
            source_searchinfo (dict) : the searchinfo from the source experiment request.
            target_searchinfo (dict) : the searchinfo for the target space for the clone destination
            target_base_model_name (str) : new name of the destination model(s)
            experiment_id (str) : id of the source experiment
            reply_handler (func): handler for each reply, optional; called as
                reply_handler(raw_reply, target_model_name).

        Returns:
            (func) a callback to copy_model that takes model_name <string> as an argument
        """
        source_prefix = EXPERIMENT_MODEL_PREFIX + experiment_id

        def callback(source_model_name):
            # only replace the experiment prefix with the new model name but keep
            # the suffixes like "_StandardScaler_0"
            target_model_name = source_model_name.replace(
                source_prefix, target_base_model_name, 1
            )

            try:
                raw_reply = copy_model(
                    source_searchinfo, source_model_name, target_searchinfo, target_model_name
                )
            except SplunkRestException as e:
                # a failed copy is returned as its REST reply instead of raised
                raw_reply = e.get_raw_reply()
            return reply_handler(raw_reply, target_model_name) if reply_handler else raw_reply

        return callback

    @staticmethod
    def _add_model_name_to_reply(raw_reply, model_name):
        """
        Tag a reply from the Splunk lookup-table-file REST endpoint with the model
        name, under the `mltk_model_name` attribute.

        1. if the reply's `messages` list is non-empty, the attribute is added to
           its first message.
        2. otherwise, a new message {type: 'INFO', text: ''} carrying the attribute
           is appended.

        Args:
            raw_reply (dict) : a dict of raw reply from Splunk lookup-table-file request
            model_name: the model name which needs to be inserted.

        Raises:
            Exception: if raw_reply['content'] is not JSON with a 'messages' list.

        Returns:
            (dict) modified copy of the reply; `raw_reply` is left untouched.
        """
        reply = copy.deepcopy(raw_reply)
        try:
            content = json.loads(raw_reply['content'])
            messages = content['messages']

            if messages:
                messages[0][MODEL_NAME_ATTR] = model_name
            else:
                message_success = {'type': "INFO", 'text': '', MODEL_NAME_ATTR: model_name}
                messages.append(message_success)

            reply['content'] = json.dumps(content)
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Invalid JSON response from REST API, Please check mlspl.log for more details."
            ) from e

        return reply

    @staticmethod
    def _promote_draft_model_callback(searchinfo):
        """
        A closure that takes the necessary info to copy a model's draft over the
        model itself, within the same namespace.

        Args:
            searchinfo (dict): searchinfo used for both the source (draft) and the
                destination model.

        Returns:
            (func) a callback to copy_model that takes model_name <string> as an argument;
            it raises SplunkRestProxyException (NOT_FOUND) when the draft model is missing.
        """

        def callback(model_name):
            draft_model_name = get_experiment_draft_model_name(model_name)

            try:
                return copy_model(searchinfo, draft_model_name, searchinfo, model_name)
            except ModelNotFoundException as e:
                cexc.log_traceback()
                logger.error(e)
                raise SplunkRestProxyException(
                    "%s: %s" % (str(e), draft_model_name), logging.ERROR, http.client.NOT_FOUND
                ) from e

        return callback

    @staticmethod
    def _delete_models(request, url_parts):
        """
        Best-effort deletion of the user-namespace models of one experiment.

        Only acts when `url_parts` holds exactly one element (the experiment id).
        Any failure is logged and swallowed so the caller carries on normally.

        Args:
            request (dict): the original request, used to build the searchinfo
            url_parts (list): url parts without /mltk/experiments
        """
        if len(url_parts) != 1:
            return

        try:
            searchinfo = searchinfo_from_request(request)
            rest_proxy = rest_proxy_from_searchinfo(searchinfo)
            model_list = get_model_list_by_experiment(
                rest_proxy, namespace='user', experiment_id=url_parts[0]
            )
            for model_name in model_list:
                url = rest_url_util.make_get_lookup_url(
                    rest_proxy, namespace='user', lookup_file=model_name
                )
                rest_proxy.make_rest_call('DELETE', url)
        except Exception:
            # deliberately best-effort: log and continue
            cexc.log_traceback()

    @staticmethod
    def _handle_clone_reply(replies):
        """
        Merge the 'messages' part of all replies into the last processed reply.

        Args:
            replies (list) : the replies from all splunk REST requests, with ['content']['messages']
                modified by _clone_experiment_model_callback()

        Raises:
            Exception: if `replies` is empty or a reply is malformed.

        Returns:
            (dict) a copy of the last processed reply (the first unsuccessful one, if any),
            whose `content` is trimmed down to the merged 'messages'.
        """
        messages = []
        merged_reply = None
        try:
            for reply in replies:
                messages.append(json.loads(reply['content'])['messages'][0])
                # copy so the caller's reply dict is not mutated below
                merged_reply = dict(reply)
                if not reply['success']:
                    break

            if merged_reply is None:
                raise ValueError('no replies to merge')

            merged_reply['content'] = json.dumps({'messages': messages})
        except Exception as e:
            cexc.log_traceback()
            raise Exception(
                "Invalid JSON response from REST API, Please check mlspl.log for more details."
            ) from e

        return merged_reply

    def _handle_reply(self, reply, reply_options, request, url_parts, method):
        """
        - Overridden from SplunkRestEndpointProxy
        - Replace '/configs/conf-experiments' in the reply with '/mltk/experiments'
        - Promote draft models when requested by a successful POST

        Args:
            reply (dict): the reply we got from '/configs/conf-experiments'
            reply_options (dict): the reply options from '_transform_request_options'
            request (dict): the original request, used to build the searchinfo for promotion
            url_parts (list): a list of url parts without /mltk/experiments
            method (string): original request's method

        Returns:
            reply: dict with `status`, JSON `payload` and `headers` built from the
            filtered reply content
        """
        proxied_prefix = '/%s' % '/'.join(self.URL_PARTS_PREFIX)

        def deproxy(string):
            # replace '/configs/conf-experiments' with '/mltk/experiments'
            return string.replace(proxied_prefix, '/mltk/experiments')

        content = json.loads(reply.get('content'))

        if content.get('origin'):
            content['origin'] = deproxy(content['origin'])

        if content.get('links'):
            links = content['links']
            for key, value in links.items():
                links[key] = deproxy(value)

        for item in content.get('entry') or []:
            item['id'] = deproxy(item['id'])
            item_links = item['links']
            for key, value in item_links.items():
                item_links[key] = deproxy(value)

        # promote the draft model to production.
        if (
            reply_options.get('promote_model')
            and method == 'POST'
            and reply.get('status') == http.client.OK
        ):
            searchinfo = searchinfo_from_request(request)

            promote_draft_model_callback = self._promote_draft_model_callback(searchinfo)
            self._handle_all_experiment_models(reply, promote_draft_model_callback)

        return {
            'status': reply.get('status', http.client.OK),
            'payload': json.dumps(content),
            'headers': {'Content-Type': 'application/json'},
        }
|