2017-02-15 4 views
4

私は以下のパンダのデータフレームを持っています。条件を満たす十分な数の行を持たないグループをフィルタリングします。

import pandas as pd 

# Initialize dataframe 
df1 = pd.DataFrame(columns=['bar', 'foo']) 
df1['bar'] = ['001', '001', '001', '001', '002', '002', '003', '003', '003'] 
df1['foo'] = [-4, -3, 2, 3, -3, -2, 0, 1, 2] 
>>> print df1 
    bar foo 
0 001 -4 
1 001 -3 
2 001 2 
3 001 3 
4 002 -3 
5 002 -2 
6 003 0 
7 003 1 
8 003 2 

次のしきい値とパラメータを考慮してください。

# Provide threshold and number of entries above and below threshold 
threshold = 0 
n_below = 2 
n_above = 2 

私はbarの特定の値をフィルタリングするデータフレームを作成したいと思います。 barのフィルタリングしたいのは、少なくともn_belowの値がfoo未満で、thresholdn_aboveの値がfooより小さく、thresholdより大きい場合です。上記例えば

  • bar = 001ためthreshold = 0未満fooの少なくともn_below = 2エントリとthreshold = 0より大きいfooの少なくともn_above = 2エントリがあるので、基bar = 001は、除外されないであろう。
  • グループbar = 002は、bar = 002の場合、少なくともn_above = 2のエントリがthreshold = 0より大きいため、除外されます。
  • グループbar = 003は、bar = 003の場合、少なくともn_below = 2のエントリがfoo,threshold = 0未満であるため、フィルタリングされます。

所望の出力は、次のようになります。

# Desired output 
    bar foo 
0 001 -4 
1 001 -3 
2 001 2 
3 001 3 

私は、これはGROUPBYと.count()を用いて達成することができると信じて、しかし私は機能するソリューションを得ることができませんでした。私はそれが2つのステップでこれを行う解決策を書くことがよりきれいかもしれないことを認識している:1)n_belowの条件を満たすための最初のフィルター; 2)その後、n_aboveの条件を満たすようにフィルタリングします。

答えて

3

groupbyfilterの方法を使用できます。

threshold = 0 
n_below = 2 
n_above = 2 
def filter_function(g): 
    '''Called by filter, g is the grouped dataframe''' 
    l = g['foo'] 
    return (sum([x < threshold for x in l]) >= n_below 
      and sum([x > threshold for x in l]) >= n_above) 

df.groupby('bar').filter(filter_function) 

# gives 
    bar foo 
0 1 -4 
1 1 -3 
2 1 2 
3 1 3 

は私が解決策の一つが考える

2

Pandas: Filtration参照してください:

threshold = 1 
n_below = 2 
n_above = 2 

df1.set_index('bar').loc[ \ 
    df1.groupby('bar')\ 
     .apply(lambda df_sub: \ 
        (df_sub['foo']<threshold).sum()>=n_below \ 
       and (df_sub['foo']>threshold).sum()>=n_above)] \ 
.reset_index('bar') 

、それは

bar foo 
0 001 -4 
1 001 -3 
2 001 2 
3 001 3 
1
idx = df1.groupby('bar').apply(lambda x: (sum(x['foo'] < threshold) >= n_below) & (sum(x['foo'] > threshold) >= n_above)) 

print df1.set_index('bar')[idx].reset_index() 

    bar foo 
0 001 -4 
1 001 -3 
2 001 2 
3 001 3 
を返します。
関連する問題