You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
43 lines
1.4 KiB
43 lines
1.4 KiB
#!/usr/bin/env python
|
|
|
|
from sklearn.cluster import DBSCAN as _DBSCAN
|
|
|
|
import cexc
|
|
from base import BaseAlgo, ClustererMixin
|
|
from util import df_util
|
|
from util.param_util import convert_params
|
|
|
|
|
|
class DBSCAN(ClustererMixin, BaseAlgo):
    """Clustering algorithm backed by scikit-learn's ``DBSCAN`` estimator."""

    def __init__(self, options):
        """Process the options and construct the underlying estimator.

        Args:
            options: Mapping of algorithm options. Its optional ``'params'``
                entry is passed through ``convert_params`` with ``eps``
                declared as a float and ``min_samples`` declared as an int,
                and the result is forwarded to the scikit-learn estimator
                as keyword arguments.
        """
        self.handle_options(options)

        raw_params = options.get('params', {})
        estimator_kwargs = convert_params(
            raw_params, floats=['eps'], ints=['min_samples']
        )
        self.estimator = _DBSCAN(**estimator_kwargs)

    def fit(self, df, options):
        """Cluster the feature columns of ``df`` and attach the labels.

        Args:
            df: Input dataframe. It is copied before feature preparation so
                the caller's dataframe is not altered by that step.
            options: Mapping of algorithm options. ``'mlspl_limits'`` is
                forwarded to feature preparation, and ``'output_name'``
                (default ``'cluster'``) names the output.

        Returns:
            The result of merging the cluster predictions into ``df``.
        """
        # NOTE(review): self.feature_variables is presumably populated by
        # handle_options() in __init__ -- confirm against the base class.
        features, nans, _ = df_util.prepare_features(
            X=df.copy(),
            variables=self.feature_variables,
            mlspl_limits=options.get('mlspl_limits'),
        )

        if nans.any():
            # Null values were found: tell the user to handle them before
            # fitting, since those entries do not get a cluster assigned.
            cexc.messages.warn(
                "NULL values found in the dataset. Clusters are not assigned for these values currently. "
                "Please consider handling null (or missing) entries to get appropriate clustering output."
            )

        labels = self.estimator.fit_predict(features.values)

        predictions = df_util.create_output_dataframe(
            y_hat=labels,
            nans=nans,
            output_names=options.get('output_name', 'cluster'),
        )
        return df_util.merge_predictions(df, predictions)
|