You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
50 lines
1.6 KiB
50 lines
1.6 KiB
#!/usr/bin/env python
|
|
from base import BaseAlgo, TransformerMixin
|
|
from codec import codecs_manager
|
|
from util.param_util import convert_params
|
|
import pandas as pd
|
|
import networkx as nx
|
|
from cexc import get_messages_logger, get_logger
|
|
# Module-level logger returned by cexc.get_messages_logger(). Despite its
# name it is used at info level (see GraphConnectedComponents.__init__).
debug = get_messages_logger()
|
|
|
|
class GraphConnectedComponents(TransformerMixin, BaseAlgo):
    """Label each row with the weakly connected component of its source node.

    The first two entries of ``self.feature_variables`` are read as the
    source and destination columns of an edge list. A directed NetworkX
    graph is built from those edges and its weakly connected components are
    numbered from 1, in the order NetworkX yields them.
    """

    def __init__(self, options):
        """Log the NetworkX version and pass ``options`` to ``handle_options``.

        ``handle_options`` is not defined in this class; presumably it is
        inherited and populates ``self.feature_variables`` -- confirm against
        the base classes.
        """
        debug.info('NetworkX Version {}'.format(nx.__version__))
        self.handle_options(options)

    def apply(self, df, options):
        """Return ``self.fit(df, options)``; applying reuses the fit logic."""
        return self.fit(df, options)

    def fit(self, df, options):
        """Return ``df`` joined with a ``connected_component`` column.

        Args:
            df: DataFrame containing the columns named in
                ``self.feature_variables``.
            options: Not read by this method; kept for the call signature.

        Returns:
            A new DataFrame (``df`` is not modified) in which each row
            carries the 1-based component id looked up by the value in its
            source column.

        Raises:
            IndexError: If ``self.feature_variables`` has fewer than two
                entries.
            KeyError: If a column named in ``self.feature_variables`` is
                missing from ``df``.
        """
        fields = self.feature_variables
        source_field = fields[0]
        target_field = fields[1]
        edges = df[fields]

        # Add all edges in one call, iterating the two columns directly.
        # Unlike DataFrame.iterrows(), this keeps each column's own dtype, so
        # node ids are not upcast when the two columns have different dtypes,
        # and it avoids building a Series per row.
        graph = nx.DiGraph()
        graph.add_edges_from(zip(edges[source_field], edges[target_field]))

        # Map every node to the id of the component that contains it.
        component_of = {}
        components = nx.algorithms.components.weakly_connected_components(graph)
        for component_id, nodes in enumerate(components, start=1):
            for node in nodes:
                component_of[node] = component_id

        # Join on the source column only: both endpoints of an edge are in
        # the same component, so the source node's id labels the whole row.
        labels = pd.DataFrame.from_dict(
            component_of, orient='index', columns=['connected_component']
        )
        return df.join(labels, on=source_field)

    @staticmethod
    def register_codecs():
        """Register ``SimpleObjectCodec`` for this class with ``codecs_manager``."""
        from codec.codecs import SimpleObjectCodec

        codecs_manager.add_codec(
            'mltk_graph.GraphConnectedComponents',
            'GraphConnectedComponents',
            SimpleObjectCodec,
        )
|