You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.8 KiB
59 lines
1.8 KiB
#!/usr/bin/env python
|
|
|
|
from sklearn.cluster import Birch as _Birch
|
|
|
|
from base import BaseAlgo, ClustererMixin
|
|
from codec import codecs_manager
|
|
from codec.codecs import BaseCodec
|
|
from codec.flatten import flatten, expand
|
|
from util import df_util
|
|
from util.param_util import convert_params
|
|
|
|
|
|
class BirchCodec(BaseCodec):
|
|
@classmethod
|
|
def encode(cls, obj):
|
|
"""Birch has circular references and must be flattened."""
|
|
flat_obj, refs = flatten(obj)
|
|
|
|
return {
|
|
'__mlspl_type': [type(obj).__module__, type(obj).__name__],
|
|
'dict': flat_obj.__dict__,
|
|
'refs': refs,
|
|
}
|
|
|
|
@classmethod
|
|
def decode(cls, obj):
|
|
import sklearn.cluster
|
|
|
|
m = sklearn.cluster._birch.Birch.__new__(sklearn.cluster._birch.Birch)
|
|
m.__dict__ = obj['dict']
|
|
|
|
return expand(m, obj['refs'])
|
|
|
|
|
|
class Birch(ClustererMixin, BaseAlgo):
|
|
def __init__(self, options):
|
|
self.handle_options(options)
|
|
|
|
out_params = convert_params(
|
|
options.get('params', {}), ints=['k'], aliases={'k': 'n_clusters'}
|
|
)
|
|
|
|
self.estimator = _Birch(**out_params)
|
|
|
|
def apply(self, df, options):
|
|
"""Apply is overriden to make prediction on chunks of 10000 rows."""
|
|
func = super(self.__class__, self).apply
|
|
return df_util.apply_in_chunks(df, func, 10000, options)
|
|
|
|
@staticmethod
|
|
def register_codecs():
|
|
from codec.codecs import SimpleObjectCodec
|
|
|
|
codecs_manager.add_codec('sklearn.cluster._birch', 'Birch', BirchCodec)
|
|
codecs_manager.add_codec('codec.flatten', 'Ref', SimpleObjectCodec)
|
|
codecs_manager.add_codec('algos.Birch', 'Birch', SimpleObjectCodec)
|
|
codecs_manager.add_codec('sklearn.cluster._birch', '_CFNode', SimpleObjectCodec)
|
|
codecs_manager.add_codec('sklearn.cluster._birch', '_CFSubcluster', SimpleObjectCodec)
|