You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

321 lines
11 KiB

#!/usr/bin/env python
# Copyright (C) 2015-2019 Splunk Inc. All Rights Reserved.
import pandas as pd
import cexc
from .BaseProcessor import BaseProcessor
import models.base
from models import deletemodels
from util.base_util import match_field_globs
from util.base_util import MLSPLNotImplementedError
from util.df_util import is_empty_df
from util.algos import initialize_algo_class
from util.mlspl_loader import MLSPLConf
from util.telemetry_util import log_algo_details
from base import ClassifierMixin, RegressorMixin
from util.processor_util import (
split_options,
load_resource_limits,
load_sampler_limits,
get_sampler,
check_sampler,
)
# Module-level logger for this file, obtained through the cexc helper module.
logger = cexc.get_logger(__name__)
# Second logger from cexc; this file uses it for warnings such as
# "The dataset has no events." (presumably surfaced to the search user -- confirm in cexc).
messages = cexc.get_messages_logger()
class FitBatchProcessor(BaseProcessor):
    """The fit batch processor receives and returns pandas DataFrames.

    Lifecycle as driven by the methods below: incoming chunks are handed to
    ``receive_input`` (which appends them to a sampler), ``process`` fits the
    algorithm on the sampled DataFrame, ``get_output`` returns predictions and
    ``save_model`` persists the fitted algorithm when a model name was given.
    """

    def __init__(self, process_options, searchinfo):
        """Initialize options for processor.

        Args:
            process_options (dict): process options; must contain 'algo_name'.
                The 'namespace' key, if present, is popped from this dict.
            searchinfo (dict): information required for search
        """
        # 'namespace' is removed before the options are split so that it ends
        # up in neither the process options nor the algo options.
        self.namespace = process_options.pop('namespace', None)
        mlspl_conf = MLSPLConf(searchinfo)

        # Split apart process & algo options
        self.process_options, self.algo_options = split_options(
            process_options, mlspl_conf, process_options['algo_name']
        )
        self.searchinfo = searchinfo

        # Convenience / readability
        self.tmp_dir = self.process_options['tmp_dir']

        # Build and validate the algo first; the temporary model is only
        # written once the options are known to be compatible.
        self.algo = self.initialize_algo(self.algo_options, self.searchinfo)
        self.check_algo_options(self.algo, self.algo_options)
        self.save_temp_model(self.algo_options, self.tmp_dir)

        algo_name = self.algo_options['algo_name']
        self.resource_limits = load_resource_limits(algo_name, mlspl_conf)

        # Accumulated time spent appending to the sampler (see receive_input).
        self._sampler_time = 0.0
        self.sampler_limits = load_sampler_limits(self.process_options, algo_name, mlspl_conf)
        self.sampler = get_sampler(self.sampler_limits)

    @staticmethod
    def initialize_algo(algo_options, searchinfo):
        """Look up the algorithm class by name and instantiate it.

        Args:
            algo_options (dict): algo options; 'algo_name' selects the class
            searchinfo (dict): information required for search

        Returns:
            (object): algo instance constructed with algo_options

        Raises:
            RuntimeError: if the class lookup or the constructor fails
        """
        algo_name = algo_options['algo_name']
        try:
            algo_class = initialize_algo_class(algo_name, searchinfo)
            return algo_class(algo_options)
        except Exception as e:
            # Keep the full traceback in the log; the user only sees the summary.
            cexc.log_traceback()
            raise RuntimeError(
                'Error while initializing algorithm "%s": %s' % (algo_name, str(e))
            )

    @staticmethod
    def check_algo_options(algo, algo_options):
        """Raise errors if options are incompatible

        Args:
            algo (object): initialized algo object
            algo_options (dict): algo options

        Raises:
            RuntimeError: if a model is to be saved together with kfold_cv, if
                the algo cannot register codecs for saving, if the algo is
                neither a classifier nor a regressor but kfold_cv was given,
                or if kfold_cv is <= 1
        """
        kfold_cv = algo_options.get('kfold_cv')

        # Pre-validate whether or not this algo supports saved models.
        if 'model_name' in algo_options:
            if kfold_cv is not None:
                raise RuntimeError('The kfold_cv option cannot be used when saving a model')
            try:
                algo.register_codecs()
            except MLSPLNotImplementedError:
                raise RuntimeError(
                    'Algorithm "%s" does not support saved models' % algo_options['algo_name']
                )
            except Exception as e:
                # Details go to the debug log only; the raised message points there.
                logger.debug("Error while calling algorithm's register_codecs method. %s", e)
                raise RuntimeError(
                    'Error while initializing algorithm. See search.log for details.'
                )

        if kfold_cv is not None:
            # kfold_cv is only accepted for classifier / regressor algos.
            if not isinstance(algo, (ClassifierMixin, RegressorMixin)):
                raise RuntimeError(
                    'Algorithm "%s" does not support the kfold_cv parameter'
                    % algo_options['algo_name']
                )
            if kfold_cv <= 1:
                raise RuntimeError('kfold_cv must be > 1')

    @staticmethod
    def match_and_assign_variables(app_name, columns, algo, algo_options):
        """Match field globs and attach variables to algo.

        Args:
            app_name (str): application name which runs the fit()
            columns (list): columns from dataframe
            algo (object): initialized algo object
            algo_options (dict): algo options

        Returns:
            (object): the same algo object, with feature_variables resolved
                against the columns and the target variable removed from them
        """
        if hasattr(algo, 'feature_variables'):
            algo.feature_variables = match_field_globs(columns, algo.feature_variables)
            log_algo_details(app_name, algo, algo_options)
        else:
            algo.feature_variables = []

        # The target must never also be used as a feature.
        if 'target_variable' in algo_options:
            # Batch fit: target comes from the options.
            target_variable = algo_options['target_variable'][0]
            if target_variable in algo.feature_variables:
                algo.feature_variables.remove(target_variable)
        elif hasattr(algo, 'target_variable'):
            # Partial fit: target is already stored on the algo.
            if algo.target_variable in algo.feature_variables:
                algo.feature_variables.remove(algo.target_variable)

        return algo

    @staticmethod
    def save_temp_model(algo_options, tmp_dir):
        """Save temp model for follow-up apply.

        Does nothing when no 'model_name' option was given.

        Args:
            algo_options (dict): algo options
            tmp_dir (str): temp directory to save model to

        Raises:
            RuntimeError: if saving the temporary model fails
        """
        if 'model_name' not in algo_options:
            return

        try:
            # The algo argument is None: only the name and options are stored.
            models.base.save_model(
                algo_options['model_name'],
                None,
                algo_options['algo_name'],
                algo_options,
                model_dir=tmp_dir,
                tmp=True,
            )
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError(
                'Error while saving temporary model "%s": %s' % (algo_options['model_name'], e)
            )

    def get_relevant_fields(self):
        """Ask algo for relevant variables and return as relevant fields.

        Also copies 'feature_variables' and the first 'target_variable' from
        the options onto the algo object.

        Returns:
            relevant_fields (list): feature, target and split_by fields, in
                that order, for whichever of those options are present
        """
        options = self.algo_options
        relevant_fields = []

        if 'feature_variables' in options:
            self.algo.feature_variables = options['feature_variables']
            relevant_fields.extend(options['feature_variables'])

        if 'target_variable' in options:
            self.algo.target_variable = options['target_variable'][0]
            relevant_fields.extend(options['target_variable'])

        if 'split_by' in options:
            relevant_fields.extend(options['split_by'])

        return relevant_fields

    def save_model(self):
        """Attempt to save the model, delete the temporary model.

        Nothing happens when the fitted DataFrame is empty or no 'model_name'
        option was given. The temporary model is removed only after the real
        model has been saved successfully.

        Raises:
            RuntimeError: if saving the model fails. A failure to delete the
                temporary model is logged as a warning and not raised.
        """
        if is_empty_df(self.df) or 'model_name' not in self.algo_options:
            return

        model_name = self.algo_options['model_name']
        try:
            models.base.save_model(
                model_name,
                self.algo,
                self.algo_options['algo_name'],
                self.algo_options,
                max_size=self.resource_limits['max_model_size_mb'],
                searchinfo=self.searchinfo,
                namespace=self.namespace,
            )
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError('Error while saving model "%s": %s' % (model_name, e))

        # Best-effort cleanup: the real model is already saved at this point.
        try:
            deletemodels.delete_model(model_name, model_dir=self.tmp_dir, tmp=True)
        except Exception as e:
            cexc.log_traceback()
            logger.warning('Exception while deleting tmp model "%s": %s', model_name, e)

    def receive_input(self, df):
        """Receive dataframe and append to sampler if necessary.

        Args:
            df (dataframe): dataframe received from controller
        """
        sample_count = self.sampler_limits['sample_count']

        # check_sampler is only invoked for the chunk during which the
        # sampler's count comes within len(df) of the configured sample_count.
        if sample_count - len(df) < self.sampler.count <= sample_count:
            check_sampler(
                sampler_limits=self.sampler_limits, class_name=self.algo_options['algo_name']
            )

        with cexc.Timer() as sampler_t:
            self.sampler.append(df)

        self._sampler_time += sampler_t.interval
        logger.debug('sampler_time=%f', sampler_t.interval)

    def process(self):
        """Get dataframe, update algo, and possibly make predictions."""
        self.df = self.sampler.get_df()

        if not is_empty_df(self.df):
            self.algo = FitBatchProcessor.match_and_assign_variables(
                self.searchinfo.get('app'), self.df.columns, self.algo, self.algo_options
            )

        # fit() copes with an empty DataFrame itself, so it is always called;
        # this guarantees self.has_applied is set after process().
        self.algo, self.df, self.has_applied = self.fit(self.df, self.algo, self.algo_options)

    @staticmethod
    def fit(df, algo, algo_options):
        """Perform the literal fitting process.

        This method updates the algo by fitting with input data. Some of the
        algorithms additionally make predictions within their fit method, thus
        the predictions are returned in dataframe type. Some other algorithms do
        not make prediction in their fit method, thus None is returned.

        Args:
            df (dataframe): dataframe to fit the algo
            algo (object): initialized/loaded algo object
            algo_options (dict): algo options

        Returns:
            algo (object): updated algo object
            df (dataframe):
                - if algo.fit makes prediction, return prediction
                - if algo.fit does not make prediction, return input df
            has_applied (bool): flag to indicate whether df represents
                original df or prediction df

        Raises:
            RuntimeError: if algo.fit raises
        """
        # Nothing to fit on: hand the input back untouched.
        if is_empty_df(df):
            return algo, df, False

        try:
            prediction_df = algo.fit(df, algo_options)
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError(
                'Error while fitting "%s" model: %s' % (algo_options['algo_name'], str(e))
            )

        # Only a DataFrame return value counts as a prediction.
        has_applied = isinstance(prediction_df, pd.DataFrame)
        if has_applied:
            df = prediction_df

        return algo, df, has_applied

    def get_output(self):
        """Override get_output from BaseProcessor.

        Check if prediction was already made, otherwise make prediction.

        Returns:
            (dataframe): output dataframe

        Raises:
            RuntimeError: if the algo's apply method raises
        """
        if is_empty_df(self.df):
            messages.warning("The dataset has no events.")
            return self.df

        if not self.has_applied:
            try:
                self.df = self.algo.apply(self.df, self.algo_options)
            except Exception as e:
                cexc.log_traceback()
                logger.debug(
                    'Error during apply phase of fit command. Check apply method of algorithm.'
                )
                raise RuntimeError(
                    'Error while fitting "%s" model: %s'
                    % (self.algo_options['algo_name'], str(e))
                )

        # An apply() that returns None is replaced by an empty DataFrame so
        # callers always receive a DataFrame on this path.
        if self.df is None:
            messages.warning('Apply method did not return any results.')
            self.df = pd.DataFrame()

        return self.df