Extract and Visualize Model Trees from Sparklyr


As of today (the Spark 2.4.0 release has already been approved and is waiting for the official announcement) your best bet*, without involving complex 3rd party tools (you can take a look at MLeap, for example), is probably to save the model and read the specification back:

ml_stage(iris_prediction_model, "random_forest") %>% 
  ml_save("/tmp/model")

rf_spec <- spark_read_parquet(sc, "rf", "/tmp/model/data/")

The result will be a Spark DataFrame with the following schema:

rf_spec %>% 
  spark_dataframe() %>% 
  invoke("schema") %>% invoke("treeString") %>% 
  cat(sep = "\n")
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)

providing information about all nodes and splits.

Feature mapping can be retrieved using column metadata:

meta <- iris_predictions %>% 
    select(features) %>% 
    spark_dataframe() %>% 
    invoke("schema") %>% invoke("apply", 0L) %>% 
    invoke("metadata") %>% 
    invoke("getMetadata", "ml_attr") %>% 
    invoke("getMetadata", "attrs") %>% 
    invoke("json") %>%
    jsonlite::fromJSON() %>% 
    dplyr::bind_rows() %>% 
    copy_to(sc, .) %>%
    rename(featureIndex = idx)

meta
# Source: spark<?> [?? x 2]
  featureIndex name        
*        <int> <chr>       
1            0 Sepal_Length
2            1 Sepal_Width 
3            2 Petal_Length
4            3 Petal_Width 

And the labels mapping, which you've already retrieved:

labels <- tibble(prediction = seq_along(iris_labels) - 1, label = iris_labels) %>%
  copy_to(sc, .)

Finally, you can combine all of these:

full_rf_spec <- rf_spec %>% 
  spark_dataframe() %>% 
  invoke("selectExpr", list("treeID", "nodeData.*", "nodeData.split.*")) %>% 
  sdf_register() %>% 
  select(-split, -impurityStats) %>% 
  left_join(meta, by = "featureIndex") %>% 
  left_join(labels, by = "prediction")

full_rf_spec
# Source: spark<?> [?? x 12]
   treeID    id prediction impurity    gain leftChild rightChild featureIndex
 *  <int> <int>      <dbl>    <dbl>   <dbl>     <int>      <int>        <int>
 1      0     0          1   0.636   0.379          1          2            2
 2      0     1          1   0      -1             -1         -1           -1
 3      0     2          0   0.440   0.367          3          8            2
 4      0     3          0   0.0555  0.0269         4          5            3
 5      0     4          0   0      -1             -1         -1           -1
 6      0     5          0   0.5     0.5            6          7            0
 7      0     6          0   0      -1             -1         -1           -1
 8      0     7          2   0      -1             -1         -1           -1
 9      0     8          2   0.111   0.0225         9         12            2
10      0     9          2   0.375   0.375         10         11            1
# ... with more rows, and 4 more variables: leftCategoriesOrThreshold <list>,
#   numCategories <int>, name <chr>, label <chr>

which, collected and separated by treeID, should give enough information** to mimic a tree-like object and build a decent plot (you can get a good understanding of the required structure by checking the rpart::rpart.object documentation and/or unclassing an rpart model; tree::tree would require less work, but its plotting utilities are far from impressive).
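As a minimal sketch of that collect-and-separate step (reusing full_rf_spec from above; converting each piece into an rpart- or tree-compatible object is left out):

local_spec <- full_rf_spec %>% collect()

# One local data frame per tree, keyed by treeID
trees <- local_spec %>% split(.$treeID)

# For example, inspect the root node of the first tree
trees[["0"]] %>% dplyr::filter(id == 0)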

An alternative path is to export your model to PMML using Sparklyr2PMML and work with that representation.

You can also check How do I visualise / plot a decision tree in Apache Spark (PySpark 1.4.1)? which suggests a third-party Python package to solve the same problem.

If you don't need anything fancy, you can create a crude plot with igraph:

library(igraph)

gframe <- full_rf_spec %>% 
  filter(treeID == 0) %>%   # Take the first tree
  mutate(
    leftCategoriesOrThreshold = ifelse(
      size(leftCategoriesOrThreshold) == 1,
      # Continuous variable case
      concat("<= ", round(concat_ws("", leftCategoriesOrThreshold), 3)),
      # Categorical variable case. Decoding variables might be involved
      # but can be achieved if needed, using column metadata or indexer labels
      concat("in {", concat_ws(",", leftCategoriesOrThreshold), "}")
    ),
    name = coalesce(name, label)) %>% 
  select(
    id, label, impurity, gain,
    leftChild, rightChild, leftCategoriesOrThreshold, name) %>%
  collect()

vertices <- gframe %>% rename(label = name, name = id)

edges <- gframe %>%
  transmute(from = id, to = leftChild, label = leftCategoriesOrThreshold) %>% 
  union_all(gframe %>% select(from = id, to = rightChild)) %>% 
  filter(to != -1)

g <- igraph::graph_from_data_frame(edges, vertices = vertices)

plot(
  g, layout = layout_as_tree(g, root = c(1)),
  vertex.shape = "rectangle",  vertex.size = 45)


* It should improve in the near future, with the newly introduced format-agnostic ML writer API (which already supports a PMML writer for selected models; hopefully new models and formats will follow).

** If you work with categorical features, you might want to map leftCategoriesOrThreshold to the respective indexed levels.

If the feature vector contains categorical variables, the output of jsonlite::fromJSON() will contain a nominal group. For example, if you had an indexed column foo with three levels, assembled at the first position, it will look something like this:

$nominal
     vals idx      name
1 a, b, c   1       foo

where the vals column is a list of variable-length vectors.

length(meta$nominal$vals[[1]])
[1] 3

The labels correspond to indices of this structure, so in the example:

  • a has label 0.0 (note that labels are double precision floating point numbers, and numbering starts from 0.0)
  • b has label 1.0

and so on. If you have a split with leftCategoriesOrThreshold equal to, let's say, c(0.0, 2.0), it means that the split is on labels {"a", "c"}.

Please also note that if categorical data is present, you might have to process it before calling copy_to - it doesn't look like copy_to supports complex fields as of now.

In Spark <= 2.3 you will have to use R code for the mapping (on the local structure some purrr should do just fine, as sketched below). In Spark 2.4 (not yet supported in sparklyr, AFAIK) it might be easier to read the metadata directly with Spark's JSON reader and map it with its higher-order functions.
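A minimal sketch of such a mapping, assuming the jsonlite::fromJSON() output from above is kept locally as attrs (the names attrs, nominal_levels and decode_categories are illustrative, not part of any API):

library(purrr)

# Lookup keyed by feature index: index -> character vector of level names
nominal_levels <- set_names(attrs$nominal$vals, attrs$nominal$idx)

# Map the 0-based indices in leftCategoriesOrThreshold back to level names
decode_categories <- function(feature_index, left_categories) {
  levels <- nominal_levels[[as.character(feature_index)]]
  if (is.null(levels)) return(left_categories)  # continuous feature, keep threshold
  levels[left_categories + 1]
}

decode_categories(1, c(0.0, 2.0))
# [1] "a" "c"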
