问题
I am using the R programming language. I combined my own code with a lengthy tutorial found here: https://michael.hahsler.net/SMU/EMIS7332/R/viz_classifier.html. In the end, I produced a plot (see `final_plot` at the end of this code).
library(cluster)
library(Rtsne)
library(dplyr)
library(randomForest)
library(caret)
library(ggplot2)
library(plotly)
# PART 1: Create data ----
# Simulate 10,000 observations: response_variable ~ var_1 + var_2 + var_3
var_1 <- rnorm(10000, 1, 4)
var_2 <- rnorm(10000, 10, 5)
var_3 <- sample(LETTERS[1:4], 10000, replace = TRUE,
                prob = c(0.1, 0.2, 0.65, 0.05))
response_variable <- sample(LETTERS[1:2], 10000, replace = TRUE,
                            prob = c(0.4, 0.6))
# put them into a data frame called "f"
f <- data.frame(var_1, var_2, var_3, response_variable)
# declare var_3 and response_variable as factors
f$response_variable <- as.factor(f$response_variable)
f$var_3 <- as.factor(f$var_3)
# create a row id so predictions can be matched back to observations
f$ID <- seq_len(nrow(f))

# PART 2: Random forest ----
# split data into train set and test set
index <- createDataPartition(f$response_variable, p = 0.7, list = FALSE)
train <- f[index, ]
test <- f[-index, ]
# create random forest statistical model
rf <- randomForest(response_variable ~ var_1 + var_2 + var_3,
                   data = train, ntree = 20, mtry = 2)
# have the model predict the test set
pred <- predict(rf, test, type = "prob")
# The probability matrix has one column per class, so the "B" column holds
# P(class == "B"). The earlier code assigned "A" when that probability exceeded
# 0.5, which swapped the two labels. Fixing the factor levels to those of the
# reference also keeps both classes present even if one is never predicted.
labels <- factor(ifelse(pred[, "B"] > 0.5, "B", "A"),
                 levels = levels(test$response_variable))
confusionMatrix(labels, test$response_variable)

# PART 3: Visualize in 2D ----
# (source: https://dpmartin42.github.io/posts/r/cluster-mixed-types)
# Gower distance handles the mix of numeric and factor predictors; the
# response and ID columns are left out of the distance computation.
gower_dist <- daisy(test[, c("var_1", "var_2", "var_3")], metric = "gower")
gower_mat <- as.matrix(gower_dist)
labels <- data.frame(labels)
labels$ID <- test$ID
tsne_obj <- Rtsne(gower_dist, is_distance = TRUE)
tsne_data <- tsne_obj$Y %>%
  data.frame() %>%
  setNames(c("X", "Y")) %>%
  mutate(cluster = factor(labels$labels),
         name = labels$ID)
# colour by the `cluster` column of the plotted data rather than by an
# external vector, so the mapping travels with the data frame
plot <- ggplot(aes(x = X, y = Y), data = tsne_data) +
  geom_point(aes(color = cluster))
plotly_plot <- ggplotly(plot)
# 2-D embedding plus predicted label: the input for the decision plot below
# (data.frame() names the two unnamed matrix columns X1 and X2)
a <- tsne_obj$Y
a <- data.frame(a)
data <- a
data$class <- labels$labels
# Plot a classifier's decision regions over the first two columns of `data`
# using base graphics.
#
# Arguments:
#   model         fitted classifier that has a predict() method
#   data          data frame; columns 1 and 2 are the features / plot axes
#   class         name (or index) of the class-label column in `data`, or NULL
#   predict_type  passed as `type` to predict()
#   resolution    number of grid points along each axis
#   showgrid      if TRUE, draw the predicted class of every grid point as a dot
#   ...           forwarded to plot()
#
# Draws on the active graphics device and invisibly returns the
# resolution x resolution integer matrix of predicted class codes (rows follow
# x, columns follow y). Note that the return value is NOT a plot object.
decisionplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE, ...) {
  if (!is.null(class)) {
    cl <- data[[class]]
    # as.integer() on a character vector gives NA (with a warning), which
    # would leave every point without colour or symbol; go through a factor
    if (is.character(cl)) cl <- factor(cl)
  } else {
    cl <- 1
  }
  data <- data[, 1:2]
  k <- length(unique(cl))

  plot(data, col = as.integer(cl) + 1L, pch = as.integer(cl) + 1L, ...)

  # make grid
  r <- vapply(data, range, numeric(2), na.rm = TRUE)
  xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
  ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
  # x varies slowest, y fastest; the byrow matrix below relies on this order
  g <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
  colnames(g) <- colnames(r)
  g <- as.data.frame(g)

  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, g, type = predict_type)
  if (is.list(p)) p <- p$class
  p <- as.factor(p)

  if (showgrid) points(g, col = as.integer(p) + 1L, pch = ".")

  z <- matrix(as.integer(p), nrow = resolution, byrow = TRUE)

  # One contour level between each pair of consecutive class codes. Counting
  # the predicted levels as well as the observed ones covers class = NULL and
  # models that predict more classes than appear in `data`; seq_len() avoids
  # the backwards 1:0 sequence when there is only one class.
  n_classes <- max(k, nlevels(p))
  if (n_classes > 1L) {
    contour(xs, ys, z, add = TRUE, drawlabels = FALSE,
            lwd = 2, levels = seq_len(n_classes - 1L) + 0.5)
  }

  invisible(z)
}
# Fit a random forest on the 2-D embedding and draw its decision regions.
# randomForest() names its tree-count argument `ntree`; the earlier spelling
# `ntrees` is not one of its arguments, so the requested value never reached
# the model.
model <- randomForest(class ~ ., data = data, mtry = 2, ntree = 500)
# `main` is forwarded through `...` to plot()
final_plot <- decisionplot(model, data, class = "class", main = "rf (1)")
Now, I would like to turn this into an "interactive" plot using the plotly library in R:
# Fails: final_plot holds the matrix that decisionplot() returns invisibly,
# not a ggplot object, so ggplotly() has no applicable method for it.
plotly_plot = ggplotly(final_plot)
But I got the following error:
Error in UseMethod("ggplotly", p) :
no applicable method for 'ggplotly' applied to an object of class "c('matrix', 'array', 'integer', 'numeric')"
Is there a way to convert "Regular" plots to "ggplot" in R? Can my "final_plot" be passed through a "plotly" object?
回答1:
As @mischva11 commented, I think it is easier to create the ggplot from scratch. Your function actually returns a matrix, not a plot object: the `plot` and `contour` functions draw directly in the active graphics window. I am not sure whether there is a way to convert these base plots to ggplot (maybe there is).
Here is a way to create a similar plot as you have in ggplot and then convert it to plotly.
# ggplot2 version of the decision plot: returns a ggplot object (which can be
# passed to ggplotly()) instead of drawing on the active device.
#
# Arguments:
#   model         fitted classifier that has a predict() method
#   data          data frame; columns 1 and 2 are the features / plot axes
#   class         name (or index) of the class-label column in `data`, or NULL
#   predict_type  passed as `type` to predict()
#   resolution    number of grid points along each axis
#   showgrid      if TRUE, add the predicted class of every grid point as a dot
#
# Returns the ggplot object.
decisionplot <- function(model, data, class = NULL, predict_type = "class",
                         resolution = 100, showgrid = TRUE) {
  # Earlier versions ignored `class` and always read a column literally named
  # "class"; keep that as the fallback so existing calls behave the same.
  if (is.null(class) && "class" %in% names(data)) class <- "class"

  feature_names <- names(data)[1:2]

  # make grid values for contour and grid points
  r <- vapply(data[, 1:2], range, numeric(2), na.rm = TRUE)
  xs <- seq(r[1, 1], r[2, 1], length.out = resolution)
  ys <- seq(r[1, 2], r[2, 2], length.out = resolution)
  grid <- cbind(rep(xs, each = resolution), rep(ys, times = resolution))
  # predict() needs the grid to carry the same feature names as `data`
  colnames(grid) <- colnames(r)
  grid <- as.data.frame(grid)

  ### guess how to get class labels from predict
  ### (unfortunately not very consistent between models)
  p <- predict(model, grid, type = predict_type)
  if (is.list(p)) p <- p$class

  # Copy into frames with fixed internal column names so the aes() calls below
  # do not depend on what the caller's columns happen to be called.
  grid_df <- data.frame(X1 = grid[[1]], X2 = grid[[2]], class = as.factor(p))
  point_df <- data.frame(X1 = data[[1]], X2 = data[[2]])

  # create ggplot with minimal theme and no grid lines; label the axes with
  # the caller's feature names
  g <- ggplot() +
    theme_minimal() +
    theme(panel.grid = element_blank()) +
    labs(x = feature_names[1], y = feature_names[2])

  if (showgrid) {
    # add labeled grid points to ggplot
    g <- g +
      geom_point(data = grid_df, aes(x = X1, y = X2, col = class), shape = ".")
  }

  # add the observed points, coloured and shaped by class when one is given
  if (is.null(class)) {
    g <- g + geom_point(data = point_df, aes(x = X1, y = X2))
  } else {
    point_df$class <- as.factor(data[[class]])
    class_title <- if (is.character(class)) class else names(data)[class]
    g <- g +
      geom_point(data = point_df,
                 aes(x = X1, y = X2, col = class, shape = class)) +
      labs(colour = class_title, shape = class_title)
  }

  # add contour curves
  # NOTE(review): newer ggplot2 releases prefer `linewidth` over `size` for
  # line width -- confirm the installed version before switching.
  g <- g +
    geom_contour(data = grid_df,
                 aes(x = X1, y = X2, z = as.integer(class)),
                 colour = "black", linetype = 1, size = rel(0.2),
                 bins = length(unique(grid_df$class)))

  # return ggplot object
  g
}
# build the decision plot as a ggplot object
final_plot <- decisionplot(model = model, data = data, class = "class")
# turn the ggplot object into an interactive plotly widget
ggplotly(p = final_plot)
This works. The final plot does not look that good, but you can play around with the parameters.
One thing that in my opinion could make the final plot better is to use `geom_raster` to plot the regions with different label predictions (instead of plotting the small points). However, when I did this, the conversion to plotly took forever (I actually gave up). I think there is an issue in the conversion to plotly when you use discrete labels for `geom_raster`, because when I converted the discrete labels to numeric values, it converted to plotly very fast.
Another option is to work directly in plot_ly, but I don't have much experience on this.
Hope this works.
来源:https://stackoverflow.com/questions/65406196/convert-regular-plot-to-ggplot-object-and-then-plotly