You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
254 lines
8.3 KiB
254 lines
8.3 KiB
#!/usr/bin/env python
|
|
|
|
import csv
|
|
import errno
|
|
import json
|
|
import os
|
|
import traceback
|
|
import uuid
|
|
from distutils.version import StrictVersion
|
|
import onnxruntime as rt
|
|
import numpy as np
|
|
|
|
import cexc
|
|
from util.constants import HOWTO_CONFIGURE_MLSPL_LIMITS
|
|
import util.models_util as models_util
|
|
import util.onnx_util as onnx_util
|
|
from codec import MLSPLEncoder, MLSPLDecoder
|
|
from util.algos import initialize_algo_class
|
|
from util.base_util import get_staging_area_path
|
|
from util.constants import DEFAULT_LOOKUPS_DIR
|
|
from util.lookups_util import (
|
|
file_name_to_path,
|
|
lookup_name_to_path_from_splunk,
|
|
parse_model_reply,
|
|
)
|
|
from util.telemetry_util import (
|
|
log_experiment_details,
|
|
log_apply_details,
|
|
log_example_details,
|
|
log_app_details,
|
|
log_model_id,
|
|
)
|
|
|
|
# Module-wide diagnostic logger (writes to mlspl.log).
logger = cexc.get_logger(__name__)
# Messages logger — presumably surfaced as user-visible Splunk search
# messages; confirm against cexc.get_messages_logger().
messages = cexc.get_messages_logger()

# Default directory used to stage model files before upload to Splunk.
model_staging_dir = get_staging_area_path()
|
|
|
|
|
|
def load_model(
    model_name,
    searchinfo,
    namespace=None,
    model_dir=DEFAULT_LOOKUPS_DIR,
    tmp=False,
    skip_model_obj=False,
):
    """Load a saved model file and reconstruct its algorithm object.

    Args:
        model_name: (str) user-visible model name.
        searchinfo: (dict) Splunk search info used for lookup resolution.
        namespace: (str or None) namespace to resolve the lookup in.
        model_dir: (str) directory for temporary models.
        tmp: (bool) True to read a temporary model directly from model_dir.
        skip_model_obj: (bool) True to skip decoding the model payload
            (model_obj is returned as None).

    Returns:
        (tuple) algo_name, model_obj, model_options.
    """
    if tmp:
        # Temporary models live directly in model_dir; raises if invalid.
        file_path = file_name_to_path(
            models_util.model_name_to_filename(model_name, tmp), model_dir
        )
    else:
        # Permanent models are resolved through Splunk's lookup machinery.
        file_path = lookup_name_to_path_from_splunk(
            lookup_name=model_name,
            file_name=models_util.model_name_to_filename(model_name),
            searchinfo=searchinfo,
            namespace=namespace,
            lookup_type='model',
        )

    logger.debug('Loading model: %s' % file_path)
    algo_name, model_data, model_options = models_util.load_algo_options_from_disk(
        file_path=file_path
    )

    model_obj = None
    if not skip_model_obj:
        algo_class = initialize_algo_class(algo_name, searchinfo)
        # Codecs must be registered before the payload can be decoded.
        if hasattr(algo_class, 'register_codecs'):
            algo_class.register_codecs()
        model_obj = decode(model_data['model'])

    # Convert pre 2.2 variable names to feature_variables and target_variable
    model_obj, model_options = convert_variable_names(model_obj, model_options)

    app_name = searchinfo.get('app')
    log_app_details(app_name)
    log_apply_details(app_name, algo_name, model_options)

    return algo_name, model_obj, model_options
|
|
|
|
|
|
def get_model_options_from_disk(model_name, searchinfo, namespace):
    """Resolve an ONNX model's lookup file via Splunk and read its contents.

    Args:
        model_name: (str) user-visible model name.
        searchinfo: (dict) Splunk search info used for lookup resolution.
        namespace: (str or None) namespace to resolve the lookup in.

    Returns:
        (tuple) algo_name, model_data, model_options as read from disk.
    """
    # Map the model name to its ONNX lookup file name, then resolve the
    # actual on-disk path through Splunk.
    onnx_file_name = onnx_util.onnx_model_name_to_lookup(model_name)
    file_path = lookup_name_to_path_from_splunk(
        lookup_name=model_name,
        file_name=onnx_file_name,
        searchinfo=searchinfo,
        namespace=namespace,
        lookup_type='model',
    )

    return models_util.load_algo_options_from_disk(file_path=file_path)
|
|
|
|
|
|
def get_session_object(model_data):
    """Create an ONNX Runtime inference session from serialized model data.

    Args:
        model_data: (dict) model info; model_data["model"] holds the encoded
            model payload as the repr of a bytes object (wrapped in b'...').

    Returns:
        (onnxruntime.InferenceSession) a session ready for inference.

    Raises:
        RuntimeError: if the payload cannot be decoded or the session
            cannot be constructed.
    """
    try:
        # The [2:-1] slice strips the leading b' and trailing ' left by
        # str()-ing a bytes payload.
        # NOTE(review): this assumes the payload is ALWAYS stored as the
        # repr of a bytes object — TODO confirm against the writer side.
        raw_model = onnx_util.decode_model_file(model_data["model"][2:-1])
        return rt.InferenceSession(raw_model)
    except Exception as e:
        # Chain the cause so the original traceback is preserved while
        # callers still see a RuntimeError.
        raise RuntimeError(str(e)) from e
|
|
|
|
|
|
def load_onnx_model(model_name, searchinfo, namespace=None):
    """Load an ONNX model by name and return an onnxruntime session for it."""
    # Only the raw model data is needed; algo name and options are ignored.
    model_data = get_model_options_from_disk(model_name, searchinfo, namespace)[1]
    return get_session_object(model_data)
|
|
|
|
|
|
def get_model_from_btool_result(btool_dict, model_name, user, app, roles, namespace):
    """Pick the on-disk path for `model_name` out of parsed btool results.

    Resolution order: the user's own namespace first (when namespace='user'),
    then an app-level match for any of the user's roles, then — unless
    namespace='app' — a "best" global candidate.

    Args:
        btool_dict: (dict) owner (user or role name) -> {model name -> path}.
            NOTE: role entries are consumed via pop(), so this function
            MUTATES btool_dict.
        model_name: (str) model to resolve.
        user: (str) current username.
        app: (str) current app context.
        roles: (iterable of str) roles held by the user.
        namespace: (str) 'user', 'app', or other for global resolution.

    Returns:
        (str or None) path to the model file, or None when nothing matches.

    Raises:
        Exception: on any unexpected error (traceback goes to mlspl.log).
    """
    try:
        if namespace == 'user' and user in btool_dict and model_name in btool_dict[user]:
            result = btool_dict[user][model_name]
            user_match_str = os.path.join(
                'users', user, app, 'lookups', models_util.model_name_to_filename(model_name)
            )
            # Here only models in the user namespace is checked, because there is a issue/bug with btool
            # if username is also a role name in Splunk (e.g. username=power and there is the "power" role),
            # btool might return objects that the user have no permission on but role does.
            if result.endswith(user_match_str):
                return result

        app_match_str = os.path.join(
            'apps', app, 'lookups', models_util.model_name_to_filename(model_name)
        )
        merged_result = None
        for role in roles:
            try:
                # pop() so the same entry is not considered again for a later role.
                path = btool_dict[role].pop(model_name)
                if path.endswith(app_match_str):
                    # App-level match wins immediately.
                    return path
                else:
                    # If "app:" is not used, check global namespace
                    # NOTE(review): keeps the lexicographically greatest path —
                    # presumably mirrors btool precedence; confirm before changing.
                    if namespace != 'app' and (merged_result is None or merged_result < path):
                        merged_result = path
            except KeyError:
                continue  # Do Nothing, go to next item
    except Exception:
        cexc.log_traceback()
        raise Exception("Please check mlspl.log for more details.")
    return merged_result
|
|
|
|
|
|
def save_model(
    model_name,
    algo,
    algo_name,
    options,
    max_size=None,
    model_dir=model_staging_dir,
    tmp=False,
    searchinfo=None,
    namespace=None,
    local=False,
):
    """Serialize a model to a CSV file and, for permanent models, hand it to Splunk.

    Args:
        model_name: (str) user-visible model name.
        algo: algorithm object to encode, or a falsy value to store an empty payload.
        algo_name: (str) algorithm name written alongside the payload.
        options: (dict) model options, JSON-encoded into the file.
        max_size: (int or None) max encoded payload size in MB; None/0 disables the check.
        model_dir: (str) directory to write into (defaults to the staging area).
        tmp: (bool) True for temporary models kept directly under model_dir.
        searchinfo: (dict or None) Splunk search info; required to move permanent models.
        namespace: (str or None) destination namespace for permanent models.
        local: (bool) True to keep the file local instead of uploading via Splunk.

    Raises:
        RuntimeError: if the encoded model exceeds max_size.
        Exception: if the model directory cannot be created.
    """
    if algo:
        algo_class = type(algo)
        # Codecs must be registered before the object can be encoded.
        if hasattr(algo_class, 'register_codecs'):
            algo_class.register_codecs()
        opaque = encode(algo)
    else:
        opaque = ''

    if max_size and max_size > 0 and len(opaque) > max_size * 1024 * 1024:
        raise RuntimeError(
            "Model exceeds size limit ({} > {}). {}".format(
                len(opaque), max_size * 1024 * 1024, HOWTO_CONFIGURE_MLSPL_LIMITS
            )
        )

    try:
        # exist_ok covers the common already-exists case; any other OSError
        # (permissions, a file in the way, bad path) is reported below.
        os.makedirs(model_dir, exist_ok=True)
    except OSError as e:
        cexc.log_traceback()
        raise Exception("Error creating model: %s, %s" % (model_name, e))

    # if we're creating a real model, generate a random name for it to avoid collisions in the upload staging area
    model_name_to_open = (
        model_name if (tmp or local) else '_' + str(uuid.uuid1()).replace('-', '_')
    )
    file_path = file_name_to_path(
        models_util.model_name_to_filename(model_name_to_open, tmp), model_dir
    )  # raises if invalid
    logger.debug(f'Saving model "{model_name}" to {file_path}')
    log_experiment_details(model_name)
    log_example_details(model_name)
    log_model_id(model_name)

    # newline='' is required by the csv module: without it embedded newlines
    # are mishandled and Windows writes doubled row terminators (\r\r\n).
    with open(file_path, mode='w', newline='') as f:
        model_writer = csv.writer(f)

        # TODO: Version attribute
        model_writer.writerow(['algo', 'model', 'options'])
        model_writer.writerow([algo_name, opaque, json.dumps(options)])

    if not (tmp or local):
        # Permanent models are staged locally, then moved into place by Splunk.
        model_filename = models_util.model_name_to_filename(model_name)
        reply = models_util.move_model_file_from_staging(
            model_filename, searchinfo, namespace, f.name
        )
        if not reply.get('success'):
            parse_model_reply(reply)
|
|
|
|
|
|
def encode(obj):
    """Serialize a model object to its MLSPL string representation.

    Raises:
        RuntimeError: when the installed numpy is older than 1.10.0.
    """
    # Guard clause: refuse to encode on numpy releases predating 1.10.0.
    if StrictVersion(np.version.version) < StrictVersion('1.10.0'):
        raise RuntimeError(
            'Python for Scientific Computing version 1.1 or later is required to save models.'
        )
    return MLSPLEncoder().encode(obj)
|
|
|
|
|
|
def decode(payload):
    """Deserialize an MLSPL payload back into a model object.

    Raises:
        RuntimeError: when the installed numpy is older than 1.10.0.
    """
    # Guard clause: refuse to decode on numpy releases predating 1.10.0.
    if StrictVersion(np.version.version) < StrictVersion('1.10.0'):
        raise RuntimeError(
            'Python for Scientific Computing version 1.1 or later is required to load models.'
        )
    return MLSPLDecoder().decode(payload)
|
|
|
|
|
|
def convert_variable_names(algo, options):
    """Convert pre-2.2 models to use new variable names.

    Pre-2.2 models stored features under 'explanatory_variables' (with the
    target under 'variables') or bare 'variables'; rename these to
    'feature_variables' / 'target_variable' in both the options dict and the
    algorithm object's __dict__.

    Args:
        algo: decoded algorithm object, or None when the caller skipped
            decoding (e.g. load_model with skip_model_obj=True) — then only
            the options dict is converted.
        options: (dict) model options, possibly containing legacy keys.

    Returns:
        (tuple) the (algo, options) pair, updated in place.

    Raises:
        RuntimeError: if the rename fails (e.g. legacy keys are inconsistent
            between options and the algo object).
    """
    try:
        if 'explanatory_variables' in options:
            options['feature_variables'] = options.pop('explanatory_variables')
            options['target_variable'] = options.pop('variables')
            # Skip attribute renames when the model object was not decoded.
            if algo is not None:
                algo.__dict__['feature_variables'] = algo.__dict__.pop('explanatory_variables')
                algo.__dict__['target_variable'] = algo.__dict__.pop('response_variable')
        elif 'variables' in options:
            options['feature_variables'] = options.pop('variables')
            if algo is not None:
                algo.__dict__['feature_variables'] = algo.__dict__.pop('variables')
    except Exception as e:
        # logger.warn is deprecated; warning() is the supported spelling.
        logger.warning(traceback.format_exc())
        raise RuntimeError('Error while converting model variable names: {}'.format(e)) from e
    return algo, options
|