You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

73 lines
2.7 KiB

#!/usr/bin/env python
from base import BaseAlgo, TransformerMixin
from codec import codecs_manager
from util.param_util import convert_params
import pandas as pd
import networkx as nx
from cexc import get_messages_logger, get_logger
# Module-level logger obtained from cexc.get_messages_logger(); used by
# GraphCentrality.__init__ to report the installed NetworkX version.
debug = get_messages_logger()
class GraphCentrality(TransformerMixin, BaseAlgo):
    """Append graph centrality scores to an edge-list DataFrame.

    Each input row is treated as one undirected, unweighted edge between the
    values of the first two feature variables. For every requested algorithm
    a column named after the algorithm is added, holding the score of the
    node found in the first (source) feature column of that row.

    Supported ``compute`` values (comma-separated, case-insensitive):
    ``degree_centrality``, ``betweenness_centrality``,
    ``eigenvector_centrality`` and ``cluster_coefficient``.

    NOTE(review): ``handle_options`` and ``feature_variables`` are not
    defined in this class; presumably they come from the base classes or the
    surrounding framework -- confirm.
    """

    # Defaults used when the corresponding parameter is not supplied.
    DEFAULT_COMPUTE = 'degree_centrality'
    DEFAULT_MAX_ITER = 1000

    def __init__(self, options):
        """Parse the algorithm parameters and store them back into ``options``.

        Args:
            options: dict of search options. Its optional ``'params'`` entry
                may carry ``compute`` (string) and ``max_iter`` (int). The
                resolved values are written to ``options['compute']`` and
                ``options['max_iter']``.
        """
        debug.info('NetworkX Version {}'.format(nx.__version__))
        self.handle_options(options)
        params = convert_params(
            options.get('params', {}),
            strs=['compute'],
            ints=['max_iter'],
        )
        options['max_iter'] = (
            params['max_iter'] if 'max_iter' in params else self.DEFAULT_MAX_ITER
        )
        options['compute'] = (
            params['compute'] if 'compute' in params else self.DEFAULT_COMPUTE
        )

    def apply(self, df, options):
        """Run the same computation as :meth:`fit` and return its result."""
        return self.fit(df, options)

    def fit(self, df, options):
        """Build the graph from ``df`` and join the requested scores onto it.

        Args:
            df: DataFrame containing at least the first two feature-variable
                columns (source and destination of each edge).
            options: dict that may hold ``'compute'`` (comma-separated
                algorithm names, optionally wrapped in double quotes) and
                ``'max_iter'`` (iteration cap for eigenvector centrality).
                Missing entries fall back to the class defaults.

        Returns:
            A copy of ``df`` with one additional column per recognised
            algorithm. Unrecognised algorithm names are ignored.
        """
        # Work on a copy so the caller's dataframe is left untouched.
        X = df.copy()

        fields = self.feature_variables
        source_field, target_field = fields[0], fields[1]

        # One undirected edge per row; a single bulk insert replaces the
        # per-row iterrows() loop.
        graph = nx.Graph()
        graph.add_edges_from(zip(X[source_field], X[target_field]))

        # Bind max_iter once so the eigenvector scorer shares the call shape
        # of the other scorers.
        max_iter = options.get("max_iter", self.DEFAULT_MAX_ITER)
        scorers = {
            'degree_centrality': nx.algorithms.centrality.degree_centrality,
            'betweenness_centrality': nx.algorithms.centrality.betweenness_centrality,
            'eigenvector_centrality': lambda g: nx.algorithms.centrality.eigenvector_centrality(
                g, max_iter=max_iter
            ),
            'cluster_coefficient': nx.algorithms.cluster.clustering,
        }

        requested = options.get("compute", self.DEFAULT_COMPUTE)
        names = requested.strip().strip('"').lower().split(',')

        seen = set()
        for raw_name in names:
            # Tolerate spaces around the commas ("a, b") and skip repeats,
            # which would otherwise make the join fail on a duplicate column.
            algo = raw_name.strip()
            if algo in seen:
                continue
            seen.add(algo)

            scorer = scorers.get(algo)
            if scorer is None:
                # Unknown names are skipped rather than treated as an error.
                continue

            scores = scorer(graph)
            score_frame = pd.DataFrame(
                list(scores.items()), columns=[source_field, algo]
            )
            X = X.join(score_frame.set_index(source_field), on=source_field)

        return X

    @staticmethod
    def register_codecs():
        """Register ``SimpleObjectCodec`` for this class with ``codecs_manager``."""
        from codec.codecs import SimpleObjectCodec
        codecs_manager.add_codec('mltk_graph.GraphCentrality', 'GraphCentrality', SimpleObjectCodec)