#!/usr/bin/env python
import csv
import errno
import json
import os
import traceback
import uuid
from distutils.version import StrictVersion

import numpy as np
import onnxruntime as rt

import cexc
import util.models_util as models_util
import util.onnx_util as onnx_util
from codec import MLSPLEncoder, MLSPLDecoder
from util.algos import initialize_algo_class
from util.base_util import get_staging_area_path
from util.constants import DEFAULT_LOOKUPS_DIR, HOWTO_CONFIGURE_MLSPL_LIMITS
from util.lookups_util import (
    file_name_to_path,
    lookup_name_to_path_from_splunk,
    parse_model_reply,
)
from util.telemetry_util import (
    log_experiment_details,
    log_apply_details,
    log_example_details,
    log_app_details,
    log_model_id,
)

logger = cexc.get_logger(__name__)
messages = cexc.get_messages_logger()

model_staging_dir = get_staging_area_path()


def load_model(
    model_name,
    searchinfo,
    namespace=None,
    model_dir=DEFAULT_LOOKUPS_DIR,
    tmp=False,
    skip_model_obj=False,
):
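    """Load a model from disk and reconstruct its algorithm object.

    Temporary models are read directly from model_dir; otherwise the file
    path is resolved through Splunk. When skip_model_obj is True, the model
    payload is not decoded and model_obj is returned as None.
    """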
    if tmp:
        file_path = file_name_to_path(
            models_util.model_name_to_filename(model_name, tmp), model_dir
        )  # raises if invalid
    else:
        file_name = models_util.model_name_to_filename(model_name)
        file_path = lookup_name_to_path_from_splunk(
            lookup_name=model_name,
            file_name=file_name,
            searchinfo=searchinfo,
            namespace=namespace,
            lookup_type='model',
        )

    logger.debug('Loading model: %s' % file_path)
    algo_name, model_data, model_options = models_util.load_algo_options_from_disk(
        file_path=file_path
    )

    if skip_model_obj:
        model_obj = None
    else:
        algo_class = initialize_algo_class(algo_name, searchinfo)
        if hasattr(algo_class, 'register_codecs'):
            algo_class.register_codecs()
        model_obj = decode(model_data['model'])
        # Convert pre-2.2 variable names to feature_variables and target_variable
        model_obj, model_options = convert_variable_names(model_obj, model_options)

    app_name = searchinfo.get('app')
    log_app_details(app_name)
    log_apply_details(app_name, algo_name, model_options)

    return algo_name, model_obj, model_options


def get_model_options_from_disk(model_name, searchinfo, namespace):
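    """Resolve an ONNX model's file path through Splunk and load its
    algorithm name, raw model data, and options from disk."""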
    lookup_name = onnx_util.onnx_model_name_to_lookup(model_name)
    file_path = lookup_name_to_path_from_splunk(
        lookup_name=model_name,
        file_name=lookup_name,
        searchinfo=searchinfo,
        namespace=namespace,
        lookup_type='model',
    )
    algo_name, model_data, model_options = models_util.load_algo_options_from_disk(
        file_path=file_path
    )
    return algo_name, model_data, model_options


def get_session_object(model_data):
    """
    Returns a session object using the onnxruntime library.

    Args:
        model_data: (dict) model info.

    Returns: (object) onnxruntime InferenceSession for the decoded model.
    """
    try:
        # The indexes are in place to remove the byte chars (b')
        session = rt.InferenceSession(onnx_util.decode_model_file(model_data["model"][2:-1]))
        return session
    except Exception as e:
        raise RuntimeError(str(e))


def load_onnx_model(model_name, searchinfo, namespace=None):
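    """Load a saved ONNX model and return an onnxruntime InferenceSession for it."""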
    _, model_data, _ = get_model_options_from_disk(model_name, searchinfo, namespace)
    return get_session_object(model_data)


def get_model_from_btool_result(btool_dict, model_name, user, app, roles, namespace):
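    """Select the correct model file path from parsed btool results,
    respecting the requested namespace (user, app, or global) and the
    user's roles."""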
    try:
        if namespace == 'user' and user in btool_dict and model_name in btool_dict[user]:
            result = btool_dict[user][model_name]
            user_match_str = os.path.join(
                'users', user, app, 'lookups', models_util.model_name_to_filename(model_name)
            )
            # Only models in the user namespace are checked here, because of an issue/bug
            # in btool: if the username is also a role name in Splunk (e.g. username=power
            # while the "power" role exists), btool might return objects that the user has
            # no permission on but the role does.
            if result.endswith(user_match_str):
                return result

        app_match_str = os.path.join(
            'apps', app, 'lookups', models_util.model_name_to_filename(model_name)
        )
        merged_result = None
        for role in roles:
            try:
                path = btool_dict[role].pop(model_name)
                if path.endswith(app_match_str):
                    return path
                else:
                    # If "app:" is not used, check the global namespace
                    if namespace != 'app' and (merged_result is None or merged_result < path):
                        merged_result = path
            except KeyError:
                continue  # Do nothing, go to the next item
    except Exception:
        cexc.log_traceback()
        raise Exception("Please check mlspl.log for more details.")
    return merged_result


def save_model(
    model_name,
    algo,
    algo_name,
    options,
    max_size=None,
    model_dir=model_staging_dir,
    tmp=False,
    searchinfo=None,
    namespace=None,
    local=False,
):
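    """Encode an algorithm object and save it as a CSV model file.

    Raises a RuntimeError if the encoded model exceeds max_size (in MB).
    Unless the model is temporary or local, it is first written to the
    staging area under a collision-safe random name and then moved into
    place through Splunk.
    """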
    if algo:
        algo_class = type(algo)
        if hasattr(algo_class, 'register_codecs'):
            algo_class.register_codecs()
        opaque = encode(algo)
    else:
        opaque = ''

    if max_size and max_size > 0 and len(opaque) > max_size * 1024 * 1024:
        raise RuntimeError(
            "Model exceeds size limit ({} > {}). {}".format(
                len(opaque), max_size * 1024 * 1024, HOWTO_CONFIGURE_MLSPL_LIMITS
            )
        )

    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST and os.path.isdir(model_dir):
            pass
        else:
            # TODO: Log traceback
            raise Exception("Error creating model: %s, %s" % (model_name, e))

    # If we're creating a real model, generate a random name for it to avoid
    # collisions in the upload staging area.
    model_name_to_open = (
        model_name if (tmp or local) else '_' + str(uuid.uuid1()).replace('-', '_')
    )
    file_path = file_name_to_path(
        models_util.model_name_to_filename(model_name_to_open, tmp), model_dir
    )  # raises if invalid

    logger.debug(f'Saving model "{model_name}" to {file_path}')
    log_experiment_details(model_name)
    log_example_details(model_name)
    log_model_id(model_name)

    with open(file_path, mode='w') as f:
        model_writer = csv.writer(f)
        # TODO: Version attribute
        model_writer.writerow(['algo', 'model', 'options'])
        model_writer.writerow([algo_name, opaque, json.dumps(options)])

    if not (tmp or local):
        model_filename = models_util.model_name_to_filename(model_name)
        reply = models_util.move_model_file_from_staging(
            model_filename, searchinfo, namespace, f.name
        )
        if not reply.get('success'):
            parse_model_reply(reply)


def encode(obj):
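    """Encode a Python object into an opaque MLSPL model payload."""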
    if StrictVersion(np.version.version) >= StrictVersion('1.10.0'):
        return MLSPLEncoder().encode(obj)
    else:
        raise RuntimeError(
            'Python for Scientific Computing version 1.1 or later is required to save models.'
        )


def decode(payload):
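    """Decode an opaque MLSPL model payload back into a Python object."""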
    if StrictVersion(np.version.version) >= StrictVersion('1.10.0'):
        return MLSPLDecoder().decode(payload)
    else:
        raise RuntimeError(
            'Python for Scientific Computing version 1.1 or later is required to load models.'
        )


def convert_variable_names(algo, options):
"""Convert pre-2.2 models to use new variable names."""
try:
if 'explanatory_variables' in options:
options['feature_variables'] = options.pop('explanatory_variables')
algo.__dict__['feature_variables'] = algo.__dict__.pop('explanatory_variables')
options['target_variable'] = options.pop('variables')
algo.__dict__['target_variable'] = algo.__dict__.pop('response_variable')
elif 'variables' in options:
options['feature_variables'] = options.pop('variables')
algo.__dict__['feature_variables'] = algo.__dict__.pop('variables')
except Exception as e:
logger.warn(traceback.format_exc())
raise RuntimeError('Error while converting model variable names: {}'.format(e))
return algo, options