You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.2 KiB
80 lines
2.2 KiB
#!/usr/bin/env python
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
|
|
|
|
class ReservoirSampler(object):
    """Keep a uniform random sample of at most ``limit`` rows from a stream.

    Rows arrive in batches (DataFrames) through :meth:`append`.  Every row
    gets two bookkeeping columns:

    * ``_gindex`` -- 0-based position of the row in the whole stream.
    * ``_slot``   -- the reservoir slot the row occupies.  When several rows
      claim the same slot, the most recently appended one wins.

    Both columns are removed again by :meth:`get_df`.

    A ``limit`` that is zero or negative disables sampling: every row is kept.

    Attributes:
        limit: Maximum number of rows to retain (``<= 0`` means unlimited).
        count: Total number of rows seen so far.
        random_state: NumPy ``RandomState`` state tuple used for the next
            draw; it is updated after every draw.
    """

    def __init__(self, limit, random_state=None):
        """Create an empty sampler.

        Args:
            limit: Reservoir size; ``<= 0`` keeps every row.
            random_state: Optional seed.  Any value accepted by
                ``numpy.random.RandomState`` works, including ``0``.  When
                omitted, the sampler starts from a snapshot of NumPy's
                current global generator state.
        """
        self.limit = limit
        self._dfs = []          # pending batches not yet merged into _df
        self._dfs_count = 0     # total number of rows held in _dfs
        self._df = pd.DataFrame()
        self.count = 0

        # Use a private generator so that neither seeding nor drawing
        # disturbs the process-wide ``np.random`` state.  ``is not None``
        # (rather than truthiness) makes seed 0 a valid, honoured seed.
        self._rng = np.random.RandomState(random_state)
        if random_state is None:
            self._rng.set_state(np.random.get_state())
        self.random_state = self._rng.get_state()

        # Pending rows are merged once they outgrow this threshold.
        self._max_dfs_len = self.limit / 5

    def _concat_df(self, df):
        """Queue ``df`` for a later merge into the reservoir."""
        self._dfs.append(df)
        self._dfs_count += len(df)

        # Never let the threshold be smaller than a single batch, otherwise
        # every append would trigger a merge.
        self._max_dfs_len = max(self._max_dfs_len, len(df))

    def _combine_dfs(self):
        """Merge queued batches into ``_df``, keeping the newest row per slot."""
        if not self._dfs:
            return

        # Older rows come first so that ``keep='last'`` lets newer rows
        # evict older ones that share a slot.  Empty frames are skipped
        # because they contribute nothing to the result.
        frames = [self._df] if len(self._df) > 0 else []
        frames.extend(frame for frame in self._dfs if len(frame) > 0)

        if frames:
            merged = pd.concat(frames)
            self._df = merged.drop_duplicates(subset='_slot', keep='last')

        self._dfs = []
        self._dfs_count = 0

    def get_df(self):
        """Return the current sample ordered by arrival position.

        The bookkeeping columns ``_gindex`` and ``_slot`` are dropped from
        the result.  The returned index holds each row's position in the
        internal storage, so it is not necessarily monotonic.

        Returns:
            A DataFrame with the sampled rows; an empty DataFrame when
            nothing has been appended yet.
        """
        if self.count == 0 and len(self._dfs) == 0:
            return self._df

        self._combine_dfs()
        self._df = self._df.reset_index(drop=True)

        ordered = self._df.sort_values('_gindex')
        return ordered.drop(['_gindex', '_slot'], axis=1)

    def append(self, new_df, copy=False):
        """Feed another batch of rows into the sampler.

        Args:
            new_df: DataFrame with the next rows of the stream.  Unless
                ``copy`` is true, the columns ``_gindex`` and ``_slot`` are
                added to this very object.
            copy: Work on a copy of ``new_df`` so the caller's frame is left
                untouched.
        """
        if len(new_df) == 0:
            return

        if copy:
            new_df = new_df.copy()

        # Number the rows by their position in the overall stream.
        seen_before = self.count
        new_df['_gindex'] = np.arange(len(new_df)) + seen_before
        self.count += len(new_df)

        # Unlimited mode, or the reservoir still has room for the whole
        # batch: every row simply takes the slot matching its position.
        if self.limit <= 0 or self.count <= self.limit:
            new_df['_slot'] = new_df['_gindex']
            self._concat_df(new_df)
            return

        # The batch straddles the limit: its head fills the remaining free
        # slots, only the tail takes part in the random replacement below.
        if seen_before < self.limit:
            head_count = self.limit - seen_before
            new_df['_slot'] = new_df['_gindex']
            self._concat_df(new_df[0:head_count])
            new_df = new_df[head_count:].copy()

        # Resume from the stored state (honours external changes to
        # ``random_state``) and remember where the draw left off.
        self._rng.set_state(self.random_state)
        rnd = self._rng.rand(len(new_df))
        self.random_state = self._rng.get_state()

        # Algorithm R: the row at 0-based position i picks a slot uniformly
        # from [0, i] and is kept only when that slot lies inside the
        # reservoir, i.e. with probability limit / (i + 1).
        new_df['_slot'] = np.floor(rnd * (new_df['_gindex'] + 1)).astype(int)
        keepers = new_df[new_df['_slot'] < self.limit]
        new_df = None  # release the batch early

        if len(keepers) > 0:
            self._concat_df(keepers)

        if self._dfs_count > self._max_dfs_len:
            self._combine_dfs()

        return
|