Applying lm() and predict() to multiple columns in a data frame

Submitted by 淺唱寂寞╮ on 2019-12-04 17:26:19
李哲源

I was inclined to close your question as a duplicate of Fitting a linear model with multiple LHS, but sadly prediction is not addressed there. On the other hand, Prediction of 'mlm' linear model object from lm() does discuss prediction, but it is some way from your situation, as you work with the formula interface rather than the matrix interface.

I did not manage to find a perfect duplicate target under the "mlm" tag, so I think it is a good idea to contribute another answer for this tag. As I said in the linked questions, predict.mlm does not support se.fit, and at the moment no answer under the "mlm" tag covers this either. So I will take this chance to fill that gap.


Here is a function to get the standard error of the predicted mean (what predict.lm returns as se.fit) for every response of an "mlm" object:

f <- function (mlmObject, newdata) {
  ## model formula
  form <- formula(mlmObject)
  ## drop response (LHS), so `newdata` need not contain the response columns
  form[[2]] <- NULL
  ## design matrix for the new data
  X <- model.matrix(form, newdata)
  ## solve t(R) %*% Q = t(X), where R is the triangular factor of the fitted QR
  Q <- forwardsolve(t(qr.R(mlmObject$qr)), t(X))
  ## unscaled standard error, one value per row of `newdata`
  unscaled.se <- sqrt(colSums(Q ^ 2))
  ## residual standard error, one value per response
  sigma <- sqrt(colSums(residuals(mlmObject) ^ 2) / mlmObject$df.residual)
  ## scaled standard error: rows are new observations, columns are responses
  tcrossprod(unscaled.se, sigma)
}
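
To see why the two factors in f can be computed separately and then combined with tcrossprod, here is a short derivation. The notation (X for the training design matrix, x_0 for one row of the new design matrix, n observations, p coefficients, j indexing the responses) is introduced here for illustration and does not appear in the original answer; it also assumes the fit is full rank, so no column pivoting occurred in the QR factorization.

```latex
\begin{aligned}
X &= QR
  \;\Longrightarrow\; X^\top X = R^\top R, \\
\operatorname{Var}\!\left(x_0^\top \hat\beta_j\right)
  &= \sigma_j^2 \, x_0^\top (X^\top X)^{-1} x_0
   = \sigma_j^2 \, x_0^\top (R^\top R)^{-1} x_0
   = \sigma_j^2 \, \bigl\lVert R^{-\top} x_0 \bigr\rVert_2^2, \\
\widehat{\operatorname{se}}_j(x_0)
  &= \hat\sigma_j \, \bigl\lVert R^{-\top} x_0 \bigr\rVert_2,
  \qquad
  \hat\sigma_j^2 = \frac{\mathrm{RSS}_j}{n - p}.
\end{aligned}
```

The norm depends only on x_0 (it is unscaled.se in the code, obtained by forward substitution with t(R)), while \hat\sigma_j depends only on the response (it is sigma in the code). The full matrix of standard errors is therefore the outer product of the two vectors.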

For your example, you can do:

## fit an `mlm`
fit <- lm(cbind(x3, x4, x5) ~ x1 + x2, data = train)

## prediction (mean only)
pred <- predict(fit, newdata = test)

#            x3          x4         x5
#1  0.555956679  0.38628159 0.60649819
#2  0.003610108  0.47653430 0.95848375
#3 -0.458483755  0.48014440 1.27256318
#4 -0.379061372 -0.03610108 1.35920578
#5  1.288808664  0.12274368 0.17870036
#6  1.389891697  0.46570397 0.01624549

## standard error of prediction (columns correspond to x3, x4, x5)
pred.se <- f(fit, newdata = test)

#          [,1]      [,2]      [,3]
#[1,] 0.1974039 0.3321300 0.2976205
#[2,] 0.3254108 0.5475000 0.4906129
#[3,] 0.5071956 0.8533510 0.7646849
#[4,] 0.6583707 1.1077014 0.9926075
#[5,] 0.5049637 0.8495959 0.7613200
#[6,] 0.3552794 0.5977537 0.5356451

We can verify that f is correct by comparing each column of pred.se with se.fit from the corresponding single-response model:

## `lm1`, `lm2` and `lm3` are defined in your question
predict(lm1, test, se.fit = TRUE)$se.fit
#        1         2         3         4         5         6 
#0.1974039 0.3254108 0.5071956 0.6583707 0.5049637 0.3552794 

predict(lm2, test, se.fit = TRUE)$se.fit
#        1         2         3         4         5         6 
#0.3321300 0.5475000 0.8533510 1.1077014 0.8495959 0.5977537 

predict(lm3, test, se.fit = TRUE)$se.fit
#        1         2         3         4         5         6 
#0.2976205 0.4906129 0.7646849 0.9926075 0.7613200 0.5356451 
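
Since train, test, lm1, lm2 and lm3 come from the question and are not reproduced here, the same check can be sketched on simulated data. Everything below is an assumption made for illustration: the data are random stand-ins rather than the data from the question, and se_mlm is a hypothetical name for a copy of f, repeated so the block runs on its own.

```r
## simulated stand-ins for `train` and `test` (NOT the data from the question)
set.seed(1)
train <- data.frame(x1 = rnorm(30), x2 = rnorm(30))
train$x3 <- 1 + 2 * train$x1 - train$x2 + rnorm(30)
train$x4 <- -1 + 0.5 * train$x1 + rnorm(30)
train$x5 <- train$x2 + rnorm(30)
test <- data.frame(x1 = rnorm(6), x2 = rnorm(6))

## same computation as `f`, copied under a hypothetical name
se_mlm <- function (mlmObject, newdata) {
  form <- formula(mlmObject)
  form[[2]] <- NULL
  X <- model.matrix(form, newdata)
  Q <- forwardsolve(t(qr.R(mlmObject$qr)), t(X))
  unscaled.se <- sqrt(colSums(Q ^ 2))
  sigma <- sqrt(colSums(residuals(mlmObject) ^ 2) / mlmObject$df.residual)
  tcrossprod(unscaled.se, sigma)
}

## joint fit: one `mlm` with three responses
fit <- lm(cbind(x3, x4, x5) ~ x1 + x2, data = train)
se.joint <- se_mlm(fit, test)

## separate fits: one ordinary `lm` per response, using predict.lm's se.fit
se.separate <- sapply(c("x3", "x4", "x5"), function (y) {
  fit1 <- lm(reformulate(c("x1", "x2"), response = y), data = train)
  predict(fit1, test, se.fit = TRUE)$se.fit
})

## stops with an error if the two 6 x 3 matrices disagree
stopifnot(isTRUE(all.equal(se.joint, se.separate, check.attributes = FALSE)))
```

The comparison ignores attributes because the joint result carries no dimnames, while sapply labels the separate results by observation and response.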