Question:
Say part of my dataframe, df[(df['person_num'] == 1) | (df['person_num'] == 2)], looks like this:

    person_num  Days  IS_TRUE
    1           1     1
    1           4     1
    1           5     0
    1           9     1
    2           1     1
    2           4     1
    2           5     0
    2           9     1

For each person_num, I want to count something like "how many times IS_TRUE=1 happens within the seven days before a certain day". So for Day 9, I count the number of IS_TRUE=1 rows from Day 2 to Day 8, and add the count to a new column, IS_TRUE_7day_WINDOW. The result would be:

    person_num  Days  IS_TRUE  IS_TRUE_7day_WINDOW
    1           1     1        0
    1           4     1        1
    1           5     0        2
    1           9     1        1
    2           1     1        0
    2           4     1        1
    2           5     0        2
    2           9     1        1

I'm thinking about using something like this:

    df.groupby('person_num').transform(pd.rolling_sum, window=7, min_periods=1)

But I think rolling_sum only works for datetime, and the code doesn't work for my dataframe. Is there an easy way to make rolling_sum work for integers (Days in my case)? Or are there other ways to quickly compute the column I want?

I used for loops to calculate IS_TRUE_7day_WINDOW, but it took an hour to get the results because my dataframe is pretty large. I guess something like rolling_sum would speed up my old code.

Answer 1:
You can do the for loop implicitly through vectorization, which is generally faster than writing an explicit for loop. Here's a working example on the data you provided:

    import pandas as pd
    import numpy as np

    df = pd.DataFrame({'Days': [1, 4, 5, 9, 1, 4, 5, 9],
                       'IS_TRUE': [1, 1, 0, 1, 1, 1, 0, 1],
                       'person_num': [1, 1, 1, 1, 2, 2, 2, 2]})

    def window(group):
        days = group['Days'].values
        # diff[i, j] is the number of days from row j to row i
        diff = np.subtract.outer(days, days)
        # count IS_TRUE over the rows that fall 1 to 7 days before each row
        group['IS_TRUE_7day_WINDOW'] = np.dot((diff > 0) & (diff <= 7),
                                              group['IS_TRUE'].values)
        return group

    df.groupby('person_num').apply(window)

Output is this:

       Days  IS_TRUE  person_num  IS_TRUE_7day_WINDOW
    0     1        1           1                    0
    1     4        1           1                    1
    2     5        0           1                    2
    3     9        1           1                    1
    4     1        1           2                    0
    5     4        1           2                    1
    6     5        0           2                    2
    7     9        1           2                    1

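To see what the mask inside window() is doing, here is a minimal sketch for one person's rows only. The variable names (days, is_true, mask, counts) are illustrative and not part of the answer above; the technique is the same pairwise difference of days followed by a masked sum.

```python
import numpy as np

# One person's rows from the question's sample data.
days = np.array([1, 4, 5, 9])
is_true = np.array([1, 1, 0, 1])

# diff[i, j] = days[i] - days[j]: how many days row j lies before row i.
diff = np.subtract.outer(days, days)

# mask[i, j] is True when row j falls 1 to 7 days before row i.
mask = (diff > 0) & (diff <= 7)

# Each row of the mask selects the IS_TRUE values inside that row's window;
# the matrix product with the 0/1 column counts them.
counts = mask.astype(int) @ is_true
print(counts.tolist())  # [0, 1, 2, 1]
```

The mask has one entry per pair of rows in the group, so memory grows with the square of the group size. That is fine for many small groups but may become a problem if a single person has a very large number of rows.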
Answer 2:
Since you mentioned that the data frame derives from a database, consider an SQL solution using a correlated subquery, which runs the calculation in the database engine rather than in Python.

The example below assumes a MySQL database; adjust the library and connection string to your actual backend (SQLite, PostgreSQL, SQL Server, etc.). The query is intended to be standard SQL that most RDBMSs accept.

SQL Solution

    import pandas as pd
    import pymysql

    conn = pymysql.connect(host="localhost", port=3306, user="username",
                           passwd="***", db="databasename")

    # COALESCE is the standard-SQL form of MySQL's IFNULL
    sql = """
        SELECT t1.Days, t1.person_num, t1.IS_TRUE,
               (SELECT COALESCE(SUM(t2.IS_TRUE), 0)
                  FROM TableName t2
                 WHERE t2.person_num = t1.person_num
                   AND t2.Days >= t1.Days - 7
                   AND t2.Days < t1.Days) AS IS_TRUE_7DAY_WINDOW
          FROM TableName t1
    """

    df = pd.read_sql(sql, conn)

OUTPUT

    Days  person_num  IS_TRUE  IS_TRUE_7DAY_WINDOW
    1     1           1        0
    4     1           1        1
    5     1           0        2
    9     1           1        1
    1     2           1        0
    4     2           1        1
    5     2           0        2
    9     2           1        1

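If no MySQL server is at hand, the same query can be tried against an in-memory SQLite database using only the standard library. This is a sketch under assumptions: the table name TableName and its column layout are taken from the query above, the sample rows come from the question, and COALESCE is used as the standard-SQL counterpart of IFNULL. An ORDER BY is added only to make the printed order predictable.

```python
import sqlite3

# Throwaway in-memory database holding the question's sample rows.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE TableName (person_num INTEGER, Days INTEGER, IS_TRUE INTEGER)"
)
rows = [(1, 1, 1), (1, 4, 1), (1, 5, 0), (1, 9, 1),
        (2, 1, 1), (2, 4, 1), (2, 5, 0), (2, 9, 1)]
conn.executemany("INSERT INTO TableName VALUES (?, ?, ?)", rows)

# For each row, the correlated subquery sums IS_TRUE over the same person's
# rows that fall in the 7 days before that row's day.
sql = """
SELECT t1.person_num, t1.Days, t1.IS_TRUE,
       (SELECT COALESCE(SUM(t2.IS_TRUE), 0)
          FROM TableName t2
         WHERE t2.person_num = t1.person_num
           AND t2.Days >= t1.Days - 7
           AND t2.Days < t1.Days) AS IS_TRUE_7DAY_WINDOW
  FROM TableName t1
 ORDER BY t1.person_num, t1.Days
"""
result = conn.execute(sql).fetchall()
for row in result:
    print(row)
conn.close()
```

On a large table, an index on (person_num, Days) would likely help the subquery, though how much depends on the database engine.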
Answer 3:
The rolling_ functions like rolling_sum slide over the rows of the DataFrame or Series in index order when deciding how far to go back. The index doesn't have to be a datetime index. Below is some code to do the calculation for each user.

First use crosstab to make a DataFrame with a column for each person_num and a row for each day.

>>> days_person = pd.crosstab(data['days'], data['person_num'],
...                           values=data['is_true'], aggfunc=pd.np.sum)
>>> days_person
person_num  1  2
days
1           1  1
4           1  1
5           0  0
9           1  1

Next I'm going to fill in missing days with 0's, because you only have a few days of data.

>>> empty_data = {n: [0]*10 for n in days_person.columns}
>>> days_person = (days_person + pd.DataFrame(empty_data)).fillna(0)
>>> days_person
person_num  1  2
days
1           1  1
2           0  0
3           0  0
4           1  1
5           0  0
6           0  0
7           0  0
8           0  0
9           1  1

Now use rolling_sum to get the table you're looking for. Note that days 1-6 will have NaN values, because there weren't enough previous days to do the calculation.

>>> pd.rolling_sum(days_person, 7)
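
The module-level pd.rolling_sum was removed in later pandas releases in favor of the .rolling() method. A rough equivalent of the steps above for a recent pandas might look like the sketch below. It makes a few assumptions that are not in the answer: it uses the question's column names (person_num, Days, IS_TRUE), builds the table with pivot_table and reindex rather than crosstab and addition, sets min_periods=1 so early days are not NaN, and shifts the result by one row so the window covers the seven days before each day and excludes the day itself, as the question asks.

```python
import pandas as pd

df = pd.DataFrame({'person_num': [1, 1, 1, 1, 2, 2, 2, 2],
                   'Days':       [1, 4, 5, 9, 1, 4, 5, 9],
                   'IS_TRUE':    [1, 1, 0, 1, 1, 1, 0, 1]})

# One column per person, one row per observed day.
days_person = df.pivot_table(index='Days', columns='person_num',
                             values='IS_TRUE', aggfunc='sum', fill_value=0)

# Add the missing days as rows of zeros so that one row equals one day.
all_days = range(df['Days'].min(), df['Days'].max() + 1)
days_person = days_person.reindex(all_days, fill_value=0)

# Sum over the 7 rows ending at each day, then shift down one row so the
# window covers the 7 days before the day and leaves out the day itself.
window = days_person.rolling(7, min_periods=1).sum().shift(1).fillna(0)

# Look up each original row's count in the day-by-person table.
df['IS_TRUE_7day_WINDOW'] = [int(window.at[d, p])
                             for d, p in zip(df['Days'], df['person_num'])]
print(df)
```

Filling in every missing day only works well when the range of days is modest; for sparse data spanning a long period, the day-by-person table can become much larger than the original frame.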