Compute winning streak with pandas

Submitted by 空扰寡人 on 2020-01-13 18:10:33

Question


I thought I knew how to do this but I'm pulling my hair out over it. I'm trying to use a function to create a new column. The function looks at the value of the win column in the current row and needs to compare it to the previous number in the win column as the if statements lay out below. The win column will only ever be 0 or 1.

import pandas as pd
data = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})
print(data)

   win
0    0
1    0
2    1
3    1
4    1
5    0
6    1

def streak(row):
    win_current_row = row['win']
    win_row_above = row['win'].shift(-1)
    streak_row_above = row['streak'].shift(-1)

    if (win_row_above == 0) & (win_current_row == 0):
        return 0
    elif (win_row_above == 0) & (win_current_row ==1):
        return 1
    elif (win_row_above ==1) & (win_current_row == 1):
        return streak_row_above + 1
    else:
        return 0

data['streak'] = data.apply(streak, axis=1)

All of this ends with the following error:

AttributeError: ("'numpy.int64' object has no attribute 'shift'", 'occurred at index 0')

In other examples I see functions that are referring to df['column'].shift(1) so I'm confused why I can't seem to do it in this instance.

The output I'm trying to get to is:

result = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1], 'streak': ['NaN', 0 , 1, 2, 3, 0, 1]})
print(result)

   win streak
0    0    NaN
1    0      0 
2    1      1
3    1      2
4    1      3
5    0      0
6    1      1

Thanks for helping to get me unstuck.


Answer 1:


A fairly common trick when using pandas is grouping by consecutive values; this trick is well described elsewhere.

To solve your particular problem, we group by runs of consecutive values and then take the cumulative sum (cumsum) within each group. A group of losses (all 0s) has a cumulative sum of 0 throughout, while a group of wins (all 1s) counts 1, 2, 3, ..., which is exactly the winning streak.

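import pandas as pd
df = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})
grouper = (df['win'] != df['win'].shift()).cumsum()  # new id whenever the value changes
df['streak'] = df['win'].groupby(grouper).cumsum()   # running total within each run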

   win  streak
0    0       0
1    0       0
2    1       1
3    1       2
4    1       3
5    0       0
6    1       1

For the sake of explanation, here is our grouper Series, which lets us group by consecutive runs of 1s and 0s:

print(grouper)

0    1
1    1
2    2
3    2
4    2
5    3
6    4
Name: win, dtype: int64
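To see why the grouper works, here is a minimal sketch that builds it step by step, assuming the same 7-row frame as the question. The intermediate names `previous` and `changed` are illustrative, not from the answer. Comparing each value with the one above marks the first row of every run as True, and the cumulative sum of those True values gives each run its own id.

```python
import pandas as pd

df = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})

previous = df['win'].shift()     # value from the row above; NaN for row 0
changed = df['win'] != previous  # True where a new run starts (NaN != 0 is True)
grouper = changed.cumsum()       # counting run starts gives each run a unique id

print(pd.DataFrame({'win': df['win'], 'previous': previous,
                    'changed': changed, 'grouper': grouper}))
```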



Answer 2:


Let's try groupby and cumcount:

m = df.win.astype(bool)
df['streak'] = (
    m.groupby([m, (~m).cumsum().where(m)]).cumcount().add(1).mul(m))

df
   win  streak
0    0       0
1    0       0
2    1       1
3    1       2
4    1       3
5    0       0
6    1       1

How it Works

First, df.win.astype(bool) converts df['win'] to its boolean equivalent (1 = True, 0 = False), stored as m.

Next,

(~m).cumsum().where(m)

0    NaN
1    NaN
2    2.0
3    2.0
4    2.0
5    NaN
6    3.0
Name: win, dtype: float64

This labels each run of contiguous 1s with a unique number and masks the 0s as NaN.
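The key step is that (~m).cumsum() counts the losses seen so far: it stays flat across a run of wins, and two win runs are always separated by at least one loss, so each run receives a distinct number. A small sketch showing the value before and after masking, assuming the question's data (`losses_so_far` and `run_id` are illustrative names):

```python
import pandas as pd

df = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})
m = df['win'].astype(bool)

losses_so_far = (~m).cumsum()      # increments on every loss, constant during a win run
run_id = losses_so_far.where(m)    # keep the id on winning rows; losses become NaN

print(pd.DataFrame({'m': m, 'losses_so_far': losses_so_far, 'run_id': run_id}))
```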

Now use groupby and cumcount to assign each row in a group a monotonically increasing number.

m.groupby([m, (~m).cumsum().where(m)]).cumcount()

0    0
1    1
2    0
3    1
4    2
5    2
6    0
dtype: int64

This is almost what we want, but it is 1) zero-based and 2) also assigns values to the 0s (losses). Adding 1 fixes the first issue, and multiplying by m fixes the second: x times 1 (True) is x, and anything times 0 (False) is 0.

m.groupby([m, (~m).cumsum().where(m)]).cumcount().add(1).mul(m)

0    0
1    0
2    1
3    2
4    3
5    0
6    1
dtype: int64

Finally, assign this back to df['streak'].
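The same cumcount-then-mask idea can be checked without any NaN group keys by reusing Answer 1's grouper as the group labels. This is a sketch combining the two answers, not code from either; `position` and `streak` are illustrative names.

```python
import pandas as pd

df = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})
m = df['win'].astype(bool)

# Run ids borrowed from Answer 1: one id per run of equal values
grouper = (df['win'] != df['win'].shift()).cumsum()

position = m.groupby(grouper).cumcount()  # zero-based position inside each run
streak = position.add(1).mul(m)           # make it one-based, then zero out losses
print(streak.tolist())
```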




Answer 3:


The reason you're getting that error is that shift() is a pandas method. Your code was trying to get the value in the row (row['win']), which is a numpy.int64, so you were trying to call shift() on a numpy.int64. What df['column'].shift(1) does is take a DataFrame column, which is a pandas Series, and shift that whole column by one row.

To test this for yourself, try print(type(data['win'])), print(type(row['win'])) and print(type(row)).

That will tell you the data type of each.
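A minimal sketch of that check, assuming the question's `data` frame. Here `data.iloc[0]` stands in for the row object that apply(axis=1) passes to the function. It also shows that shift(1), not shift(-1), is what brings the previous row's value into the current row; shift(-1) gives each row the value from the row below.

```python
import pandas as pd

data = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})

column = data['win']   # the whole column: a pandas Series, which has .shift()
row = data.iloc[0]     # one row, the same kind of object apply(axis=1) passes in
value = row['win']     # a single scalar (numpy.int64) with no .shift() method

print(type(column))
print(type(row))
print(type(value))

# shift(1) moves values down one row, so each row sees the value from the row above;
# row 0 has no row above and gets NaN.
previous_win = column.shift(1)
print(previous_win.tolist())
```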

Also, you're going to get an error when you get to
streak_row_above = row['streak'].shift(-1)

because you're referring to row['streak'] before it is created.
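Because each row's streak depends on the previous row's streak, one straightforward fix for the original row-by-row approach is an explicit loop instead of DataFrame.apply. The sketch below uses a hypothetical helper, `winning_streak` (not from any of the answers), that applies the question's rules: a win extends the previous streak by 1 and a loss resets it to 0. It returns 0 rather than NaN for the first row, matching the answers above.

```python
import pandas as pd

def winning_streak(wins):
    """Hypothetical helper: running win streak, reset to 0 after each loss."""
    streaks = []
    previous = 0  # assumed streak before the first row
    for win in wins:
        # win == 1 extends the streak; win == 0 resets it
        previous = previous + 1 if win == 1 else 0
        streaks.append(previous)
    return streaks

data = pd.DataFrame({'win': [0, 0, 1, 1, 1, 0, 1]})
data['streak'] = winning_streak(data['win'])
print(data)
```

The vectorized groupby answers avoid the Python-level loop, which matters for large frames; the loop mainly shows why apply cannot work here, since apply never sees the values it is still producing.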



Source: https://stackoverflow.com/questions/52976336/compute-winning-streak-with-pandas
