You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

59 lines
1.8 KiB

#!/usr/bin/env python
from sklearn.cluster import Birch as _Birch
from base import BaseAlgo, ClustererMixin
from codec import codecs_manager
from codec.codecs import BaseCodec
from codec.flatten import flatten, expand
from util import df_util
from util.param_util import convert_params
class BirchCodec(BaseCodec):
    """Codec that converts a Birch estimator to and from a plain dict."""

    @classmethod
    def encode(cls, obj):
        """Return a dict representation of ``obj``.

        Birch has circular references and must be flattened, so the object
        is passed through ``flatten`` first. The resulting dict carries the
        module and class name of ``obj``, the ``__dict__`` of the flattened
        object, and the references returned by ``flatten``.
        """
        flattened, references = flatten(obj)
        obj_type = type(obj)

        payload = {}
        payload['__mlspl_type'] = [obj_type.__module__, obj_type.__name__]
        payload['dict'] = flattened.__dict__
        payload['refs'] = references
        return payload

    @classmethod
    def decode(cls, obj):
        """Rebuild an estimator from a dict with ``'dict'`` and ``'refs'`` keys.

        An instance is created with ``__new__`` (so ``__init__`` is not run),
        its ``__dict__`` is replaced by ``obj['dict']``, and the result of
        ``expand`` on that instance and ``obj['refs']`` is returned.
        """
        import sklearn.cluster

        birch_cls = sklearn.cluster._birch.Birch
        restored = birch_cls.__new__(birch_cls)
        restored.__dict__ = obj['dict']
        return expand(restored, obj['refs'])
class Birch(ClustererMixin, BaseAlgo):
    """Clustering algorithm backed by scikit-learn's ``Birch`` estimator."""

    def __init__(self, options):
        """Create the underlying estimator from ``options``.

        Args:
            options: mapping of algorithm options; its optional ``'params'``
                entry supplies the estimator parameters.
        """
        # handle_options is not defined in this class; it comes from a base.
        self.handle_options(options)
        # 'k' is passed as an int parameter and aliased to 'n_clusters'
        # (presumably coerced/renamed by convert_params -- confirm there).
        out_params = convert_params(
            options.get('params', {}), ints=['k'], aliases={'k': 'n_clusters'}
        )
        self.estimator = _Birch(**out_params)

    def apply(self, df, options):
        """Apply is overridden to make prediction on chunks of 10000 rows.

        Args:
            df: the data to run the parent ``apply`` on.
            options: options forwarded to ``df_util.apply_in_chunks``.

        Returns:
            Whatever ``df_util.apply_in_chunks`` returns.
        """
        # Name the class explicitly. super(self.__class__, self) resolves
        # against the runtime type, so a subclass inheriting this method
        # would look up Birch.apply again and recurse without end.
        func = super(Birch, self).apply
        return df_util.apply_in_chunks(df, func, 10000, options)

    @staticmethod
    def register_codecs():
        """Register the codecs for this algorithm with ``codecs_manager``."""
        from codec.codecs import SimpleObjectCodec

        codecs_manager.add_codec('sklearn.cluster._birch', 'Birch', BirchCodec)
        codecs_manager.add_codec('codec.flatten', 'Ref', SimpleObjectCodec)
        codecs_manager.add_codec('algos.Birch', 'Birch', SimpleObjectCodec)
        codecs_manager.add_codec('sklearn.cluster._birch', '_CFNode', SimpleObjectCodec)
        codecs_manager.add_codec('sklearn.cluster._birch', '_CFSubcluster', SimpleObjectCodec)