Calculating multinomial logit model prediction probabilities

試著忘記壹切 submitted on 2020-02-22 22:41:31

Question


Please try to give a parameterized solution (there are more than three alternatives).

I have a dict with beta values:

{'B_X1': 2.0, 'B_X2': -3.0}

And this data frame:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789
   6.75    4.69    9.59    5.52    9.69    7.40
   7.46    4.94    3.01    1.78    1.38    4.68
   2.05    7.30    4.08    7.02    8.24    8.49
   5.60    7.88    8.11    5.98    4.60    1.39
   1.80    8.28    9.16    7.34    7.69    6.16
   3.73    6.93    8.93    2.58    3.48    6.04
   8.06    8.88    7.06    6.76    4.68    7.82
   5.00    7.29    5.86    3.92    5.67    4.10
   2.49    2.55    4.66    7.15    6.26    7.87
   1.50    3.35    5.70    9.86    4.83    1.17
   8.19    7.72    9.56    6.61    4.15    3.64
   2.43    9.54    9.15    4.41    9.18    7.85
   2.71    3.24    4.56    6.22    7.89    9.93
   5.96    4.34    5.26    8.63    9.81    9.40

123, 456, and 789 are the alternatives.

I want to calculate the prediction probability using this formula:
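The formula was an image that is missing from this copy. Judging from the answer's code and the expected results (which it reproduces exactly), it appears to be the standard multinomial logit probability, where V_a is the linear utility of alternative a. With more alternatives, the denominator gains one term per alternative:

```latex
$$
P_j = \frac{e^{V_j}}{e^{V_j} + e^{V_k} + e^{V_s}},
\qquad
V_a = B_{X1}\, X1_a + B_{X2}\, X2_a
$$
```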

Here j, k, and s denote the alternatives.

Expected result:

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
   7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
   2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
   2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024

The probabilities should sum to 1 in every row.
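As a quick arithmetic check of the formula, the first row can be computed directly with NumPy. The values are copied from that row; nothing here is from the original post beyond the data and betas:

```python
import numpy as np

# Betas from the question and the first data row (alternatives 123, 456, 789)
b_x1, b_x2 = 2.0, -3.0
x1 = np.array([6.75, 4.69, 9.59])
x2 = np.array([5.52, 9.69, 7.40])

v = b_x1 * x1 + b_x2 * x2          # linear utility of each alternative
p = np.exp(v) / np.exp(v).sum()    # multinomial logit probabilities

# p.round(3) gives 0.49, 0.0, 0.51, matching the first expected row
print(p.round(3), p.sum())
```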


Expected result with a constant for each alternative, using these coefficients: {'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}

 X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
   6.75    4.69    9.59    5.52    9.69    7.40  0.440  0.000  0.560
   7.46    4.94    3.01    1.78    1.38    4.68  0.977  0.023  0.000
   2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
   5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
   1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
   3.73    6.93    8.93    2.58    3.48    6.04  0.021  0.952  0.027
   8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
   5.00    7.29    5.86    3.92    5.67    4.10  0.180  0.102  0.717
   2.49    2.55    4.66    7.15    6.26    7.87  0.034  0.604  0.363
   1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
   8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
   2.43    9.54    9.15    4.41    9.18    7.85  0.034  0.034  0.932
   2.71    3.24    4.56    6.22    7.89    9.93  0.978  0.021  0.001
   5.96    4.34    5.26    8.63    9.81    9.40  0.970  0.001  0.029
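This second table appears to use the same formula, with each alternative's utility gaining its own constant B_a. That reading reproduces the table; for the first row, for example:

```latex
$$
V_a = B_a + B_{X1}\, X1_a + B_{X2}\, X2_a, \qquad a \in \{123, 456, 789\}
$$
$$
V_{123} = 0.1 + 2(6.75) - 3(5.52) = -2.96, \quad
V_{456} = 0.2 + 2(4.69) - 3(9.69) = -19.49, \quad
V_{789} = 0.3 + 2(9.59) - 3(7.40) = -2.72
$$
$$
P_{123} = \frac{e^{-2.96}}{e^{-2.96} + e^{-19.49} + e^{-2.72}} \approx 0.440
$$
```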

Answer 1:


IIUC:

Turn the columns into a MultiIndex:

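import numpy as np
import pandas as pd

# Split 'X1_123' into ('X1', '123'): one level for the variable, one for the alternative
df = df.set_axis(df.columns.str.split('_', expand=True), axis=1)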
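For illustration, here is what the split produces on a small hypothetical column index (a subset of the question's column names):

```python
import pandas as pd

cols = pd.Index(['X1_123', 'X1_456', 'X2_123', 'X2_456'])
# On an Index, str.split with expand=True returns a MultiIndex, one level per piece
mi = cols.str.split('_', expand=True)
print(mi.tolist())  # [('X1', '123'), ('X1', '456'), ('X2', '123'), ('X2', '456')]
```

The first level holds the variable names that the keys of B must match; the second holds the alternative codes that end up in the P_ column names.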

Then define B so that its keys match the variable prefixes in the column names (X1, X2):

B = {'X1': 2.0, 'X2': -3.0}

Then

def f(b, x):
    # The dict b is aligned to x's columns (X1, X2) by key; sum over variables, then exponentiate
    return np.exp((b * x).sum(axis=1))

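# Stack the alternatives into the row index so each row holds (X1, X2) for one alternative
parts = f(B, df.stack()).unstack()

# Normalize each row so the probabilities sum to 1
preds = parts.div(parts.sum(axis=1), axis=0)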

# Join the probabilities under a 'P' level, then flatten the columns back to 'X1_123', 'P_123', ...
df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
    lambda d: d.set_axis(list(map('_'.join, d.columns)), axis=1)
)

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024
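One caveat with exponentiating the raw utilities, as the code above does: np.exp overflows to inf for arguments above roughly 709, and the affected row then contains NaN after the division. Subtracting each row's largest utility first leaves the probabilities unchanged, because the common factor cancels in the ratio. A minimal sketch, assuming the utilities (the summed `b * x` values before `np.exp`) have already been unstacked into one column per alternative; the function name is hypothetical:

```python
import numpy as np
import pandas as pd


def stable_logit_probs(utilities: pd.DataFrame) -> pd.DataFrame:
    """Row-wise logit probabilities from a (rows x alternatives) utility frame."""
    # Subtract each row's max so the largest exponent is exp(0) = 1
    shifted = utilities.sub(utilities.max(axis=1), axis=0)
    expv = np.exp(shifted)
    return expv.div(expv.sum(axis=1), axis=0)


# Hypothetical utilities; the first row would overflow a naive np.exp
u = pd.DataFrame({'123': [1000.0, -3.06], '456': [999.0, -19.69], '789': [0.0, -3.02]})
print(stable_logit_probs(u).round(3))
```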

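Wrapped in a single function (it expects the original frame with flat column names such as X1_123, not the MultiIndex version created above):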

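def f(df, b):
    # Two column levels: variable (X1, X2) and alternative (123, 456, ...)
    d = df.set_axis(df.columns.str.split('_', expand=True), axis=1)
    # Exponentiated utility per (row, alternative), with alternatives back as columns
    parts = np.exp(d.stack().mul(b).sum(axis=1).unstack())
    # Normalize each row and put the probabilities under a 'P' column level
    preds = pd.concat({'P': parts.div(parts.sum(axis=1), axis=0)}, axis=1).round(3)
    d = d.join(preds)
    # Flatten ('P', '123') -> 'P_123'
    d.columns = list(map('_'.join, d.columns))
    return d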

f(df, B)

    X1_123  X1_456  X1_789  X2_123  X2_456  X2_789  P_123  P_456  P_789
0     6.75    4.69    9.59    5.52    9.69    7.40  0.490  0.000  0.510
1     7.46    4.94    3.01    1.78    1.38    4.68  0.979  0.021  0.000
2     2.05    7.30    4.08    7.02    8.24    8.49  0.001  0.998  0.001
3     5.60    7.88    8.11    5.98    4.60    1.39  0.000  0.000  1.000
4     1.80    8.28    9.16    7.34    7.69    6.16  0.000  0.002  0.998
5     3.73    6.93    8.93    2.58    3.48    6.04  0.024  0.952  0.024
6     8.06    8.88    7.06    6.76    4.68    7.82  0.000  1.000  0.000
7     5.00    7.29    5.86    3.92    5.67    4.10  0.210  0.107  0.683
8     2.49    2.55    4.66    7.15    6.26    7.87  0.038  0.623  0.339
9     1.50    3.35    5.70    9.86    4.83    1.17  0.000  0.000  1.000
10    8.19    7.72    9.56    6.61    4.15    3.64  0.000  0.005  0.995
11    2.43    9.54    9.15    4.41    9.18    7.85  0.041  0.037  0.922
12    2.71    3.24    4.56    6.22    7.89    9.93  0.981  0.019  0.001
13    5.96    4.34    5.26    8.63    9.81    9.40  0.975  0.001  0.024
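The answer does not cover the second case in the question (a constant per alternative). A minimal sketch that handles both cases and any number of alternatives, assuming the coefficient dict uses the question's naming ('B_<var>' for a shared coefficient, 'B_<alt>' for a constant) and that no variable shares its name with an alternative code. Unlike the answer, it looks up columns directly instead of using stack/unstack; `mnl_probs` is a hypothetical name:

```python
import numpy as np
import pandas as pd


def mnl_probs(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """Append P_<alt> columns to a frame whose columns are named '<var>_<alt>'.

    params follows the question's naming: 'B_<var>' for a coefficient shared by
    all alternatives and 'B_<alt>' for an optional alternative-specific constant.
    Assumes every variable has a column for every alternative.
    """
    # Alternatives in column order, taken from the suffix after the first '_'
    alts = list(dict.fromkeys(c.split('_', 1)[1] for c in df.columns))
    # Strip the 'B_' prefix: 'B_X1' -> 'X1', 'B_123' -> '123'
    coefs = {k.split('_', 1)[1]: val for k, val in params.items()}
    asc = {a: coefs.get(a, 0.0) for a in alts}          # constants, 0 if absent
    betas = {k: val for k, val in coefs.items() if k not in asc}
    # Utility of each alternative: V_a = ASC_a + sum over vars of beta * <var>_<a>
    v = pd.DataFrame(
        {a: asc[a] + sum(b * df[f'{var}_{a}'] for var, b in betas.items()) for a in alts},
        index=df.index,
    )
    v = v.sub(v.max(axis=1), axis=0)    # shift each row; cancels in the ratio
    e = np.exp(v)
    return df.join(e.div(e.sum(axis=1), axis=0).add_prefix('P_'))


# First and sixth rows of the question's data
data = pd.DataFrame({
    'X1_123': [6.75, 3.73], 'X1_456': [4.69, 6.93], 'X1_789': [9.59, 8.93],
    'X2_123': [5.52, 2.58], 'X2_456': [9.69, 3.48], 'X2_789': [7.40, 6.04],
})
print(mnl_probs(data, {'B_X1': 2.0, 'B_X2': -3.0}).round(3))
print(mnl_probs(data, {'B_X1': 2.0, 'B_X2': -3.0,
                       'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}).round(3))
```

Without constants this gives 0.490/0.000/0.510 and 0.024/0.952/0.024 for the two rows, as in the answer; with the constants it gives 0.440/0.000/0.560 and 0.021/0.952/0.027, matching the question's second table.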


Source: https://stackoverflow.com/questions/59938586/calculating-multinomial-logit-model-prediction-probabilities
