You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
215 lines
7.8 KiB
215 lines
7.8 KiB
#!/usr/bin/env python
|
|
# Copyright (C) 2015-2019 Splunk Inc. All Rights Reserved.
|
|
import pandas as pd
|
|
|
|
import cexc
|
|
import models
|
|
from .FitBatchProcessor import FitBatchProcessor
|
|
from util.base_util import MLSPLNotImplementedError
|
|
from util.df_util import is_empty_df
|
|
from util.mlspl_loader import MLSPLConf
|
|
from util.lookup_exceptions import ModelNotFoundException
|
|
from util.processor_util import split_options, load_resource_limits
|
|
|
|
# Module-level logger obtained from cexc, named after this module.
logger = cexc.get_logger(__name__)
# Messages logger obtained from cexc; used below to emit warnings.
messages = cexc.get_messages_logger()
|
|
|
|
|
|
class FitPartialProcessor(FitBatchProcessor):
    """The fit partial processor receives and returns pandas DataFrames.

    This processor inherits from FitBatchProcessor and uses a handful of its
    methods. The partial processor does not need sampling and has a few
    additional things it needs to keep track of, including:
        - attempting to load a model
        - checking for discrepancies between search & saved model
        - handling new categorical values as specified by the unseen_value param
    """

    def __init__(self, process_options, searchinfo):
        """Initialize options for processor.

        Tries to load the algo from a saved model first and falls back to
        initializing it from scratch when no saved model is found.

        Args:
            process_options (dict): process options; the 'namespace' key is
                popped from this dict (the caller's dict is mutated)
            searchinfo (dict): information required for search
        """
        # Split apart process & algo options
        self.namespace = process_options.pop('namespace', None)
        mlspl_conf = MLSPLConf(searchinfo)
        self.process_options, self.algo_options = split_options(
            process_options, mlspl_conf, process_options['algo_name']
        )
        self.searchinfo = searchinfo

        # Convenience / readability
        self.tmp_dir = self.process_options['tmp_dir']

        # Try to load algo from a saved model
        self.algo, self.algo_options = self.initialize_algo_from_model(
            self.algo_options, self.searchinfo, namespace=self.namespace
        )
        loaded_from_model = self.algo is not None
        if not loaded_from_model:
            # No saved model available: initialize algo from scratch
            self.algo = self.initialize_algo(self.algo_options, self.searchinfo)

        # Validate options in both cases (model name present, no kfold_cv)
        self.check_algo_options(self.algo, self.algo_options)

        if loaded_from_model:
            # The saved model's options replaced the ones given in the search
            self.warn_about_new_parameters()

        self.save_temp_model(self.algo_options, self.tmp_dir)
        self.resource_limits = load_resource_limits(self.algo_options['algo_name'], mlspl_conf)

    @staticmethod
    def initialize_algo_from_model(algo_options, searchinfo, namespace):
        """Init algo from model if possible, and catch discrepancies.

        Args:
            algo_options (dict): algo options
            searchinfo (dict): information required for search
            namespace (string): namespace, 'user' or 'app'

        Returns:
            algo (object/None): loaded algo or None
            algo_options (dict): the options stored with the loaded model when
                a model was loaded, otherwise the `algo_options` passed in

        Raises:
            RuntimeError: if loading the model fails for any reason other than
                the model not being found, or if the loaded model does not
                match the supplied options
        """
        if 'model_name' not in algo_options:
            return None, algo_options

        try:
            model_algo_name, algo, model_options = models.base.load_model(
                algo_options['model_name'], searchinfo, namespace=namespace
            )
        except ModelNotFoundException:
            # A missing model is not an error: the caller starts from scratch
            return None, algo_options
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError(
                'Failed to load model "%s". Exception: %s.'
                % (algo_options['model_name'], str(e))
            )

        if algo is None:
            return None, algo_options

        FitPartialProcessor.catch_model_discrepancies(
            algo_options, model_options, model_algo_name
        )

        # Pre 2.2 models do not save algo_name in their model options
        # So we must re add them here to be compatible with 2.2+ versions
        model_options['algo_name'] = algo_options['algo_name']
        return algo, model_options

    @staticmethod
    def warn_about_new_parameters():
        """Warn that parameters supplied with the search are ignored."""
        # NOTE(review): this uses cexc.messages rather than the module-level
        # `messages` object -- confirm whether they are the same logger.
        cexc.messages.warn(
            'Partial fit on existing model ignores newly supplied parameters. '
            'Parameters supplied at model creation are used instead'
        )

    @staticmethod
    def catch_model_discrepancies(algo_options, model_options, model_algo_name):
        """Check to see if algo name or input columns are different from the model.

        Args:
            algo_options (dict): algo options
            model_options (dict): model options
            model_algo_name (str): name of algo from loaded model

        Raises:
            RuntimeError: if the algorithm names differ, or if an input feature
                variable is not among the model's feature variables
        """
        # Check for discrepancy between algorithm name of the model loaded and
        # algorithm name specified in input. This is an explicit comparison
        # rather than an `assert`, because asserts are removed under `python -O`.
        if algo_options['algo_name'] != model_algo_name:
            raise RuntimeError(
                "Model was trained using algorithm %s but found %s in input"
                % (model_algo_name, algo_options['algo_name'])
            )

        # Check for discrepancy between model columns and input columns
        model_features = model_options.get('feature_variables', [])
        algo_features = algo_options.get('feature_variables', [])

        for var in algo_features:
            if var not in model_features:
                raise RuntimeError(
                    "Model was trained on data with different columns than given input. {} in algo_features but not in model_features".format(
                        var
                    )
                )

    @staticmethod
    def check_algo_options(algo, algo_options):
        """Validate processor options.

        Args:
            algo (object): initialized algo (not inspected here)
            algo_options (dict): algo options

        Raises:
            RuntimeError: if no model name is given or kfold_cv is set
        """
        if 'model_name' not in algo_options:
            raise RuntimeError(
                'You must save a model if you fit the model with partial_fit enabled'
            )

        if algo_options.get('kfold_cv') is not None:
            raise RuntimeError('kfold_cv cannot be used with partial_fit')

    @staticmethod
    def fit(algo, df, options):
        """Perform the partial fit.

        An empty dataframe is skipped and the algo is returned unchanged.

        Args:
            algo (object): algo object
            df (dataframe): dataframe to fit on
            options (dict): process options

        Returns:
            algo (object): updated algorithm

        Raises:
            RuntimeError: if the algo does not support partial fit, or if
                fitting fails for any other reason
        """
        if is_empty_df(df):
            return algo

        try:
            algo.partial_fit(df, options)
        except MLSPLNotImplementedError:
            raise RuntimeError(
                'Algorithm "%s" does not support partial fit' % options['algo_name']
            )
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError(
                'Error while fitting "%s" model: %s' % (options['algo_name'], str(e))
            )

        return algo

    def receive_input(self, df):
        """Override FitBatchProcessor, simply attach df to self.

        Args:
            df (dataframe): dataframe to receive
        """
        self.df = df

    def process(self):
        """Run fit and update algo."""
        if is_empty_df(self.df):
            return

        self.algo = self.match_and_assign_variables(
            self.searchinfo.get('app'), self.df.columns, self.algo, self.algo_options
        )
        self.algo = self.fit(self.algo, self.df, self.algo_options)

    def get_output(self):
        """Predict if necessary & return appropriate dataframe.

        Returns:
            (dataframe): output dataframe
        """
        if not is_empty_df(self.df):
            self.df = self.algo.apply(self.df, self.algo_options)
            if self.df is None:
                # Always hand a dataframe back to the caller, even when apply
                # produced nothing.
                messages.warning('Apply method did not return any results.')
                self.df = pd.DataFrame()
        else:
            messages.warning("The dataset has no events.")

        return self.df