You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.7 KiB
73 lines
2.7 KiB
#!/usr/bin/env python
|
|
from base import BaseAlgo, TransformerMixin
|
|
from codec import codecs_manager
|
|
from util.param_util import convert_params
|
|
import pandas as pd
|
|
import networkx as nx
|
|
from cexc import get_messages_logger, get_logger
|
|
# Module-level logger obtained from cexc.get_messages_logger(); GraphCentrality
# writes its informational messages through it.
# NOTE(review): presumably these messages surface in the search UI — confirm in cexc.
debug = get_messages_logger()
|
|
|
|
class GraphCentrality(TransformerMixin, BaseAlgo):
    """Append per-node graph metrics to each row.

    The first two feature fields are read as the endpoints (source,
    destination) of an undirected edge. A networkx graph is built from all
    rows, every requested metric is computed per node, and the score of each
    row's *source* node is added as a new column named after the metric.

    Params (parsed in ``__init__``):
        compute: comma-separated metric names, default ``degree_centrality``.
            Supported: degree_centrality, betweenness_centrality,
            eigenvector_centrality, cluster_coefficient.
        max_iter: iteration cap passed to eigenvector_centrality,
            default 1000.
    """

    # Defaults used when a param is not supplied. fit() also falls back to
    # them when the options dict it receives lacks the resolved values.
    _DEFAULT_COMPUTE = 'degree_centrality'
    _DEFAULT_MAX_ITER = 1000

    def __init__(self, options):
        """Resolve the ``compute`` and ``max_iter`` params.

        The resolved values are written back into ``options`` as
        ``options['compute']`` and ``options['max_iter']`` because ``fit``
        and ``apply`` read them from the options dict they receive.
        NOTE(review): this relies on the caller handing the same dict (or a
        saved copy of it) to fit/apply — confirm against the MLTK framework.
        """
        debug.info('NetworkX Version {}'.format(nx.__version__))
        self.handle_options(options)

        out_params = convert_params(
            options.get('params', {}),
            strs=['compute'],
            ints=['max_iter'],
        )
        options['max_iter'] = out_params.get('max_iter', self._DEFAULT_MAX_ITER)
        options['compute'] = out_params.get('compute', self._DEFAULT_COMPUTE)

    def apply(self, df, options):
        """Score ``df``; identical to ``fit`` because fit stores no state."""
        return self.fit(df, options)

    def fit(self, df, options):
        """Build the graph from ``df`` and append one column per metric.

        Args:
            df: input DataFrame; must contain the first two feature fields.
            options: options dict; ``compute`` and ``max_iter`` are read from
                it, falling back to the class defaults when absent.

        Returns:
            A copy of ``df`` with one extra column per supported metric named
            in ``compute``. Each cell holds the score of that row's source node,
            or NaN when the source value is missing.

        Raises:
            RuntimeError: if fewer than two feature fields were given.
            Exceptions raised by networkx (e.g. when eigenvector_centrality
            does not converge within ``max_iter``) propagate unchanged.
        """
        src_dest_name = self.feature_variables
        if len(src_dest_name) < 2:
            raise RuntimeError(
                'GraphCentrality requires two fields: a source node field '
                'and a destination node field')
        src, dest = src_dest_name[0], src_dest_name[1]

        # Work on a copy so the caller's DataFrame is not modified.
        X = df.copy()

        # Undirected graph from the (src, dest) pairs. Rows with a missing
        # endpoint are skipped: networkx rejects None as a node, and each NaN
        # would otherwise become its own distinct node and skew the scores.
        edges = X[[src, dest]].dropna()
        graph = nx.Graph()
        graph.add_edges_from(zip(edges[src], edges[dest]))

        max_iter = options.get('max_iter', self._DEFAULT_MAX_ITER)
        compute = options.get('compute', self._DEFAULT_COMPUTE)
        for algo in self._parse_compute(compute):
            scores = self._score_nodes(graph, algo, max_iter)
            if scores is None:
                debug.warning(
                    'GraphCentrality: ignoring unsupported compute value "{}"'.format(algo))
                continue
            # Look up each row's source node; nodes not in the graph map to NaN.
            X[algo] = X[src].map(scores)

        return X

    @staticmethod
    def _parse_compute(value):
        """Split the comma-separated ``compute`` option into metric names.

        Leading/trailing double quotes and whitespace are stripped, names are
        lower-cased, and empty entries and duplicates are dropped. The
        first-seen order is preserved.
        """
        names = []
        for name in value.strip().strip('"').lower().split(','):
            name = name.strip()
            if name and name not in names:
                names.append(name)
        return names

    @staticmethod
    def _score_nodes(graph, algo, max_iter):
        """Return a ``{node: score}`` dict for ``algo``, or None if unsupported."""
        centrality = nx.algorithms.centrality
        scorers = {
            'degree_centrality': centrality.degree_centrality,
            'betweenness_centrality': centrality.betweenness_centrality,
            'eigenvector_centrality':
                lambda g: centrality.eigenvector_centrality(g, max_iter=max_iter),
            'cluster_coefficient': nx.algorithms.cluster.clustering,
        }
        scorer = scorers.get(algo)
        if scorer is None:
            return None
        if graph.number_of_nodes() == 0:
            # eigenvector_centrality raises on the null graph; with no nodes
            # there is nothing to score for any metric.
            return {}
        return scorer(graph)

    @staticmethod
    def register_codecs():
        """Register ``SimpleObjectCodec`` for ``mltk_graph.GraphCentrality``."""
        from codec.codecs import SimpleObjectCodec

        codecs_manager.add_codec('mltk_graph.GraphCentrality', 'GraphCentrality', SimpleObjectCodec)
|