You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
2.1 KiB

#!/usr/bin/env python
from base import BaseAlgo, TransformerMixin
from codec import codecs_manager
from util.param_util import convert_params
import pandas as pd
import numpy as np
import networkx as nx
from cexc import get_messages_logger, get_logger
# Module-level logger; used below to record the NetworkX version each time the
# algorithm is constructed.
debug = get_messages_logger()
class MinimumSpanningTree(TransformerMixin, BaseAlgo):
    """Filter an edge list down to the edges of a minimum spanning tree.

    The first two entries of ``self.feature_variables`` name the columns that
    hold the two endpoints of each edge. The input rows are loaded into an
    undirected ``networkx.Graph``, a minimum spanning tree (a spanning forest
    when the graph is disconnected) is computed, and only the input rows whose
    endpoints form a tree edge are returned.

    The optional ``weight`` parameter names the column holding the edge
    weights. When it is omitted, the sentinel value ``'one'`` is used and
    every edge gets weight 1.
    """

    # Sentinel for the ``weight`` setting meaning "all edges have weight 1".
    # NOTE(review): a weight column literally named 'one' cannot be selected.
    _UNIT_WEIGHT = 'one'

    def __init__(self, options):
        """Read the algorithm options and remember the weight setting.

        Args:
            options: Options mapping. ``options['params']`` may carry a
                ``weight`` entry naming the weight column. The resolved value
                is written back to ``options['weight']`` (as before) and is
                also kept on the instance as ``self.weight``.
        """
        debug.info('NetworkX Version {}'.format(nx.__version__))
        self.handle_options(options)
        out_params = convert_params(
            options.get('params', {}),
            strs=['weight'],
        )
        if 'weight' in out_params:
            self.weight = out_params['weight']
        else:
            self.weight = self._UNIT_WEIGHT
        # Kept for backward compatibility: fit() historically read the weight
        # setting from the options dict.
        options['weight'] = self.weight

    def apply(self, df, options):
        """Run the same logic as :meth:`fit` and return its result."""
        return self.fit(df, options)

    def fit(self, df, options):
        """Return the rows of ``df`` that lie on a minimum spanning tree.

        Args:
            df: Edge list. Must contain the two endpoint columns named by
                ``self.feature_variables`` and, unless unit weights are used,
                the weight column.
            options: Options mapping. ``options['weight']`` is used when
                present; otherwise the value remembered by ``__init__`` (or
                the unit-weight sentinel) is used.

        Returns:
            A new DataFrame with the same columns as ``df``, the surviving
            rows in their original order and a fresh 0..n-1 index. ``df``
            itself is not modified.

        Raises:
            KeyError: If an endpoint column or the weight column is missing
                from ``df``.
        """
        src_col = self.feature_variables[0]
        dst_col = self.feature_variables[1]

        # Prefer the per-call options, but do not fail when this dict never
        # went through __init__ (it would then lack the 'weight' key).
        fallback = getattr(self, 'weight', self._UNIT_WEIGHT)
        weight = (options or {}).get('weight', fallback)

        # Work on plain column values instead of iterrows(): iterrows() upcasts
        # mixed numeric rows (e.g. int node ids become floats), so graph nodes
        # would no longer correspond to the values stored in the frame.
        sources = df[src_col].tolist()
        targets = df[dst_col].tolist()
        if weight == self._UNIT_WEIGHT:
            weights = [1] * len(sources)
        else:
            weights = df[weight].tolist()

        # Build the graph. A repeated edge overwrites the weight stored
        # earlier, exactly as repeated add_edge() calls always did.
        G = nx.Graph()
        for u, v, w in zip(sources, targets, weights):
            G.add_edge(u, v, weight=w)

        T = nx.minimum_spanning_tree(G)

        # An undirected graph may report an edge as (v, u) although the row
        # listed it as (u, v). Compare unordered endpoint pairs so those rows
        # are kept; this also avoids the ambiguity of "<src>_<dst>" string
        # keys when node names themselves contain underscores.
        tree_edges = {frozenset(edge) for edge in T.edges()}
        on_tree = np.asarray(
            [frozenset(pair) in tree_edges for pair in zip(sources, targets)],
            dtype=bool,
        )

        # reset_index mirrors the fresh index the former inner merge produced.
        return df.loc[on_tree].reset_index(drop=True)

    @staticmethod
    def register_codecs():
        """Register the codec used to serialize instances of this algorithm."""
        from codec.codecs import SimpleObjectCodec
        codecs_manager.add_codec(
            'mltk_graph.MinimumSpanningTree',
            'MinimumSpanningTree',
            SimpleObjectCodec,
        )