Simple logistic regression with Statsmodels: Adding an intercept and visualizing the logistic regression equation

假如想象 submitted on 2020-05-16 05:54:09

Question


Using Statsmodels, I am trying to generate a simple logistic regression model to predict whether a person smokes or not (Smoke) based on their height (Hgt).

I have a feeling that an intercept needs to be included in the logistic regression model, but I am not sure how to add one using the add_constant() function. I am also unsure why the error below is raised.
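For reference, the model in question (and the curve that the f() function in the code draws) is the standard one-predictor logistic regression equation:

$$P(\text{Smoke}=1 \mid \text{Hgt}) = \frac{e^{b_0 + b_1\,\text{Hgt}}}{1 + e^{b_0 + b_1\,\text{Hgt}}}$$

Leaving out the intercept fixes $b_0 = 0$, which forces the fitted probability to be exactly $1/2$ at $\text{Hgt} = 0$. The single remaining coefficient $b_1$ can then only change how steep the curve is, not shift it left or right over the observed range of heights. The column of ones that add_constant() creates is what gives the model a free $b_0$.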

This is the dataset, Pulse.CSV: https://drive.google.com/file/d/1FdUK9p4Dub4NXsc-zHrYI-AGEEBkX98V/view?usp=sharing

The full code and output are in this PDF file: https://drive.google.com/file/d/1kHlrAjiU7QvFXF2a7tlTSFPgfpq9bOXJ/view?usp=sharing

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 
reg_log = sm.Logit(y,x1,missing='Drop')
results_log = reg_log.fit()
def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))
f_sorted = np.sort(f(x1,results_log.params[0],results_log.params[1]))
x_sorted = np.sort(np.array(x1))
plt.scatter(x1,y,color='C0')
plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/opt/anaconda3/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_value(self, series, key)
   4729         try:
-> 4730             return self._engine.get_value(s, k, tz=getattr(series.dtype, "tz", None))
   4731         except KeyError as e1:
((( Truncated for brevity )))
IndexError: index out of bounds

Answer 1:


Array-based Statsmodels models such as sm.Logit do not add an intercept by default (formulas in statsmodels.formula.api do), so you need to include one manually with sm.add_constant().

import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
raw_data = pd.read_csv('Pulse.csv')
raw_data
x1 = raw_data['Hgt']
y = raw_data['Smoke'] 

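# add_constant() prepends a column of ones named 'const'; its coefficient is the intercept
x1 = sm.add_constant(x1)

# 'drop' (lowercase) is the documented value for discarding rows with missing data
reg_log = sm.Logit(y,x1,missing='drop')
results_log = reg_log.fit()

print(results_log.summary())

# Fitted logistic curve: P(Smoke = 1) = exp(b0 + b1*x) / (1 + exp(b0 + b1*x))
def f(x,b0,b1):
    return np.array(np.exp(b0+x*b1) / (1 + np.exp(b0+x*b1)))

# x1 now also contains 'const', so sort and evaluate on the Hgt column only
x_sorted = np.sort(np.array(x1['Hgt']))
# Look coefficients up by name rather than by position
f_sorted = f(x_sorted,results_log.params['const'],results_log.params['Hgt'])

plt.scatter(x1['Hgt'],y,color='C0')

plt.xlabel('Hgt', fontsize = 20)
plt.ylabel('Smoked', fontsize = 20)
plt.plot(x_sorted,f_sorted,color='C8')
plt.show()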

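This also resolves the error. Without an intercept the model has only one parameter (Hgt), so results_log.params[1] points to a position that does not exist, which is the IndexError at the end of the traceback. With the constant added, params holds both 'const' and 'Hgt'.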
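The failure can be reproduced without the data. The sketch below hand-builds a one-element Series shaped like results_log.params from the no-intercept fit (the numbers 0.02 and -1.5 are placeholders, not fitted estimates). Position 1 does not exist, so asking for it raises IndexError. Indexing by name, as in params['Hgt'], also avoids the integer-as-position fallback that newer pandas releases deprecate for Series.__getitem__.

```python
import pandas as pd

# Stand-in for results_log.params from the no-intercept model: one coefficient only.
# The value 0.02 is a placeholder, not a real estimate.
params_no_const = pd.Series([0.02], index=['Hgt'])

try:
    params_no_const.iloc[1]  # position 1 does not exist
    error = None
except IndexError as exc:
    error = exc

print(type(error).__name__)

# With an intercept there are two labelled entries; look them up by name.
# -1.5 is likewise a placeholder value.
params_with_const = pd.Series([-1.5, 0.02], index=['const', 'Hgt'])
b0, b1 = params_with_const['const'], params_with_const['Hgt']
print(b0, b1)
```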



Source: https://stackoverflow.com/questions/61560569/simple-logistic-regression-with-statsmodels-adding-an-intercept-and-visualizing
