#!/usr/bin/env python
# Copyright (C) 2015-2019 Splunk Inc. All Rights Reserved.
import pandas as pd

import cexc
from .BaseProcessor import BaseProcessor
import models.base
from models import deletemodels
from util.base_util import match_field_globs
from util.base_util import MLSPLNotImplementedError
from util.df_util import is_empty_df
from util.algos import initialize_algo_class
from util.mlspl_loader import MLSPLConf
from util.telemetry_util import log_algo_details
from base import ClassifierMixin, RegressorMixin
from util.processor_util import (
    split_options,
    load_resource_limits,
    load_sampler_limits,
    get_sampler,
    check_sampler,
)

logger = cexc.get_logger(__name__)
messages = cexc.get_messages_logger()


class FitBatchProcessor(BaseProcessor):
    """The fit batch processor receives and returns pandas DataFrames."""

    def __init__(self, process_options, searchinfo):
        """Initialize options for processor.

        Splits the incoming options into process options and algo options,
        instantiates the algo, validates the options against it, writes the
        temporary model (when a model name is given) and sets up the sampler.

        Note: the 'namespace' key is popped from ``process_options``, so the
        caller's dict is modified.

        Args:
            process_options (dict): process options
            searchinfo (dict): information required for search
        """
        # Split apart process & algo options
        self.namespace = process_options.pop('namespace', None)
        mlspl_conf = MLSPLConf(searchinfo)
        self.process_options, self.algo_options = split_options(
            process_options, mlspl_conf, process_options['algo_name']
        )
        self.searchinfo = searchinfo

        # Convenience / readability
        self.tmp_dir = self.process_options['tmp_dir']

        self.algo = self.initialize_algo(self.algo_options, self.searchinfo)

        # Validate before writing the temp model so incompatible options
        # fail without leaving a temp model behind.
        self.check_algo_options(self.algo, self.algo_options)
        self.save_temp_model(self.algo_options, self.tmp_dir)

        self.resource_limits = load_resource_limits(self.algo_options['algo_name'], mlspl_conf)

        # Accumulated time spent appending to the sampler (see receive_input).
        self._sampler_time = 0.0
        self.sampler_limits = load_sampler_limits(
            self.process_options, self.algo_options['algo_name'], mlspl_conf
        )
        self.sampler = get_sampler(self.sampler_limits)

    @staticmethod
    def initialize_algo(algo_options, searchinfo):
        """Look up the algo class by name and instantiate it.

        Args:
            algo_options (dict): algo options; must contain 'algo_name'
            searchinfo (dict): information required for search

        Returns:
            algo (object): algo object constructed from algo_options

        Raises:
            RuntimeError: if looking up or constructing the algo raises; the
                traceback is logged before re-raising.
        """
        algo_name = algo_options['algo_name']
        try:
            algo_class = initialize_algo_class(algo_name, searchinfo)
            return algo_class(algo_options)
        except Exception as e:
            cexc.log_traceback()
            raise RuntimeError(
                'Error while initializing algorithm "%s": %s' % (algo_name, str(e))
            )

    @staticmethod
    def check_algo_options(algo, algo_options):
        """Raise errors if options are incompatible

        Args:
            algo (object): initialized algo object
            algo_options (dict): algo options

        Raises:
            RuntimeError: if a model is being saved together with kfold_cv,
                if the algo's register_codecs call fails, if the algo is
                neither a classifier nor a regressor but kfold_cv is given,
                or if kfold_cv is <= 1.
        """
        # Pre-validate whether or not this algo supports saved models.
        if 'model_name' in algo_options:
            if algo_options.get('kfold_cv') is not None:
                raise RuntimeError('The kfold_cv option cannot be used when saving a model')
            try:
                algo.register_codecs()
            except MLSPLNotImplementedError:
                raise RuntimeError(
                    'Algorithm "%s" does not support saved models' % algo_options['algo_name']
                )
            except Exception as e:
                # Details go to the debug log only; the user-facing error
                # stays generic.
                logger.debug(
                    "Error while calling algorithm's register_codecs method. {}".format(str(e))
                )
                raise RuntimeError(
                    'Error while initializing algorithm. See search.log for details.'
                )

        can_use_kfold_cv = isinstance(algo, (ClassifierMixin, RegressorMixin))
        if algo_options.get('kfold_cv') is not None:
            if not can_use_kfold_cv:
                raise RuntimeError(
                    'Algorithm "%s" does not support the kfold_cv parameter'
                    % algo_options['algo_name']
                )
            if algo_options.get('kfold_cv') <= 1:
                raise RuntimeError('kfold_cv must be > 1')

    @staticmethod
    def match_and_assign_variables(app_name, columns, algo, algo_options):
        """Match field globs and attach variables to algo.

        The target variable, when known, is removed from the matched feature
        variables so it is not used as a feature of itself.

        Args:
            app_name (str): application name which runs the fit()
            columns (list): columns from dataframe
            algo (object): initialized algo object
            algo_options (dict): algo options

        Returns:
            algo (object): the same algo object with feature_variables set
        """
        if hasattr(algo, 'feature_variables'):
            algo.feature_variables = match_field_globs(columns, algo.feature_variables)
            log_algo_details(app_name, algo, algo_options)
        else:
            algo.feature_variables = []

        # Batch fit
        if 'target_variable' in algo_options:
            target_variable = algo_options['target_variable'][0]

            if target_variable in algo.feature_variables:
                algo.feature_variables.remove(target_variable)

        # Partial fit
        elif hasattr(algo, 'target_variable'):
            if algo.target_variable in algo.feature_variables:
                algo.feature_variables.remove(algo.target_variable)

        return algo

    @staticmethod
    def save_temp_model(algo_options, tmp_dir):
        """Save temp model for follow-up apply.

        Does nothing when no 'model_name' is present in algo_options.

        Args:
            algo_options (dict): algo options
            tmp_dir (str): temp directory to save model to

        Raises:
            RuntimeError: if saving the temporary model raises; the traceback
                is logged before re-raising.
        """
        if 'model_name' in algo_options:
            try:
                # The algo argument is None here: only the options are
                # stored in the temporary model.
                models.base.save_model(
                    algo_options['model_name'],
                    None,
                    algo_options['algo_name'],
                    algo_options,
                    model_dir=tmp_dir,
                    tmp=True,
                )
            except Exception as e:
                cexc.log_traceback()
                raise RuntimeError(
                    'Error while saving temporary model "%s": %s'
                    % (algo_options['model_name'], e)
                )

    def get_relevant_fields(self):
        """Ask algo for relevant variables and return as relevant fields.

        Also copies the feature and target variables from the algo options
        onto the algo object.

        Returns:
            relevant_fields (list): feature variables, target variable and
                split_by fields found in the algo options, in that order
        """
        relevant_fields = []
        if 'feature_variables' in self.algo_options:
            self.algo.feature_variables = self.algo_options['feature_variables']
            relevant_fields.extend(self.algo_options['feature_variables'])

        if 'target_variable' in self.algo_options:
            self.algo.target_variable = self.algo_options['target_variable'][0]
            relevant_fields.extend(self.algo_options['target_variable'])

        if 'split_by' in self.algo_options:
            relevant_fields.extend(self.algo_options['split_by'])

        return relevant_fields

    def save_model(self):
        """Attempt to save the model, delete the temporary model.

        Nothing happens when the processed dataframe is empty or when no
        'model_name' was given.

        Raises:
            RuntimeError: if saving the model raises. A failure to delete the
                temporary model is only logged as a warning.
        """
        if not is_empty_df(self.df) and 'model_name' in self.algo_options:
            try:
                models.base.save_model(
                    self.algo_options['model_name'],
                    self.algo,
                    self.algo_options['algo_name'],
                    self.algo_options,
                    max_size=self.resource_limits['max_model_size_mb'],
                    searchinfo=self.searchinfo,
                    namespace=self.namespace,
                )
            except Exception as e:
                cexc.log_traceback()
                raise RuntimeError(
                    'Error while saving model "%s": %s' % (self.algo_options['model_name'], e)
                )
            try:
                deletemodels.delete_model(
                    self.algo_options['model_name'], model_dir=self.tmp_dir, tmp=True
                )
            except Exception as e:
                # Best effort: the real model is already saved, so a stale
                # temp model must not fail the command.
                cexc.log_traceback()
                logger.warning(
                    'Exception while deleting tmp model "%s": %s',
                    self.algo_options['model_name'],
                    e,
                )

    def receive_input(self, df):
        """Receive dataframe and append to sampler if necessary.

        Args:
            df (dataframe): dataframe received from controller
        """
        # True only for the chunk that takes the sampler from "within the
        # sample_count limit" to "past it", so check_sampler runs once at the
        # crossing. NOTE(review): check_sampler presumably reports that
        # sampling kicked in -- confirm in util.processor_util.
        if (
            self.sampler_limits['sample_count'] - len(df)
            < self.sampler.count
            <= self.sampler_limits['sample_count']
        ):
            check_sampler(
                sampler_limits=self.sampler_limits, class_name=self.algo_options['algo_name']
            )

        with cexc.Timer() as sampler_t:
            self.sampler.append(df)
        self._sampler_time += sampler_t.interval

        logger.debug('sampler_time=%f', sampler_t.interval)

    def process(self):
        """Get dataframe, update algo, and possibly make predictions.

        When the sampled dataframe is empty the algo is left untouched and
        ``self.has_applied`` is not assigned.
        """
        self.df = self.sampler.get_df()
        if not is_empty_df(self.df):
            self.algo = FitBatchProcessor.match_and_assign_variables(
                self.searchinfo.get('app'), self.df.columns, self.algo, self.algo_options
            )
            self.algo, self.df, self.has_applied = self.fit(
                self.df, self.algo, self.algo_options
            )

    @staticmethod
    def fit(df, algo, algo_options):
        """Perform the literal fitting process.

        This method updates the algo by fitting with input data. Some of the
        algorithms additionally make predictions within their fit method, thus
        the predictions are returned in dataframe type. Some other algorithms do
        not make prediction in their fit method, thus None is returned.

        Args:
            df (dataframe): dataframe to fit the algo
            algo (object): initialized/loaded algo object
            algo_options (dict): algo options

        Returns:
            algo (object): updated algo object
            df (dataframe):
                - if algo.fit makes prediction, return prediction
                - if algo.fit does not make prediction, return input df
            has_applied (bool): flag to indicate whether df represents
                original df or prediction df

        Raises:
            RuntimeError: if algo.fit raises; the traceback is logged before
                re-raising.
        """
        has_applied = False

        if not is_empty_df(df):
            try:
                prediction_df = algo.fit(df, algo_options)
            except Exception as e:
                cexc.log_traceback()
                raise RuntimeError(
                    'Error while fitting "%s" model: %s' % (algo_options['algo_name'], str(e))
                )

            # Only a DataFrame return value counts as a prediction.
            has_applied = isinstance(prediction_df, pd.DataFrame)

            if has_applied:
                df = prediction_df

        return algo, df, has_applied

    def get_output(self):
        """Override get_output from BaseProcessor.

        Check if prediction was already made, otherwise make prediction.

        Returns:
            (dataframe): output dataframe

        Raises:
            RuntimeError: if the algo's apply method raises.
        """
        if not is_empty_df(self.df):
            if not self.has_applied:
                try:
                    self.df = self.algo.apply(self.df, self.algo_options)
                except Exception as e:
                    cexc.log_traceback()
                    logger.debug(
                        'Error during apply phase of fit command. Check apply method of algorithm.'
                    )
                    raise RuntimeError(
                        'Error while fitting "%s" model: %s'
                        % (self.algo_options['algo_name'], str(e))
                    )

            if self.df is None:
                # Hand back an empty frame instead of None.
                messages.warning('Apply method did not return any results.')
                self.df = pd.DataFrame()
        else:
            messages.warning("The dataset has no events.")

        return self.df