Making Plots with Purrr
Published:
I was recently asked to make 4 plots for a collaborator. The plots are all the same, just a scatter plot and a non-linear trend line. Every time I have to do something repetitive, I wince, especially with respect to plots. I thought I would take this opportunity to write a short blog post on how to use functional programming in R to make the same plot for similar yet different data.
Side note: The result could probably be obtained through clever use of ggplot2’s geom_smooth, but I don’t think that workflow would generalize well.
Step 1: Understand the Problem
The data contain some clinical variables and the concentration of a drug in the body. My collaborator just wants a scatter plot and a trend line. Linear regression on the raw concentrations isn’t appropriate because concentration must be positive and preliminary looks at the data show a non-linear trend. We both agreed a log transform is appropriate for the concentration data. So the plan is to:
- Log the concentration data
- Perform a linear regression on the log scale
- Transform the data and the predictions back to the original scale
- Plot the resulting exponential trend as well as the original data.
I have to do that for 4 variables.
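For a single covariate, the four steps above might look like the following base R sketch. The data are simulated, and the names dat, x, conc, fit, and grid are placeholders of mine rather than anything from the real data set.

```r
# Simulated stand-in data: one covariate and a strictly positive concentration
set.seed(1)
x    <- runif(50, 0, 10)
conc <- exp(0.5 + 0.2 * x + rnorm(50, sd = 0.3))
dat  <- data.frame(x = x, conc = conc)

# Steps 1-2: log the concentration and fit a linear regression on the log scale
fit <- lm(log(conc) ~ x, data = dat)

# Step 3: predict on an evenly spaced grid and transform back with exp()
grid      <- data.frame(x = seq(min(dat$x), max(dat$x), length.out = 20))
grid$pred <- exp(predict(fit, newdata = grid))

# Step 4: plot the original data with the exponential trend on top
plot(dat$x, dat$conc, xlab = "x", ylab = "Concentration")
lines(grid$x, grid$pred, col = "red")
```

Because the predictions are exponentiated, the trend line can never dip below zero, which is the point of working on the log scale.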
Step 2: Put 4 Data Sets Into 1 Dataframe
I’ve loaded the data into a variable called c.data. There are 5 columns all in all (concentration and 4 covariates). What I want is 4 separate data sets, each with 2 columns (1 for concentration, 1 for the covariate). If I use gather, group_by the covariate names, and then nest, I can get what I want. I would suggest you make your own data set as I have described and follow along so you can see what is happening.
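If you don't have a similar data set on hand, a simulated stand-in along these lines should be enough to follow along. The covariate names (Age, Weight, Height, Creatinine) and all the coefficients are invented for illustration and are not taken from the real data.

```r
# Hypothetical stand-in for c.data: 1 concentration column + 4 numeric covariates
set.seed(42)
n <- 100
c.data <- data.frame(
  Age        = runif(n, 20, 80),
  Weight     = runif(n, 50, 110),
  Height     = runif(n, 150, 200),
  Creatinine = runif(n, 0.5, 1.5)
)

# Generate concentration on the log scale so it is always positive
c.data$Concentration <- exp(
  1 + 0.01 * c.data$Age - 0.005 * c.data$Weight + rnorm(n, sd = 0.2)
)

str(c.data)
```

All four covariates are numeric here, which matters because gather stacks them into a single value column.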
library(tidyverse)
# Make 4 datasets and store in one dataset
step.2 = c.data %>%
  gather(var, value, -Concentration) %>%
  group_by(var) %>%
  nest(.key = 'obs.data')
Now I have 4 datasets which are essentially the same, and they all live in step.2.
Step 3: Map The Data To a Linear Regression
I can use map to store models fitted to each covariate. Each of the smaller datasets has a column called Concentration and another called value. If we had one of these smaller datasets in front of us, we would do lm(log(Concentration) ~ . , data = smaller.data) to fit the linear model. Since we have a column of datasets, I do:
step.3 = step.2 %>%
  mutate(
    model = map(obs.data, ~ lm(log(Concentration) ~ ., data = .x)) # Fit models here
  )
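To check that each row really holds its own fitted model, one option is to pull a coefficient out of every fit. The sketch below uses simulated toy data in the same shape as step.2; the names toy, fits, and slope are mine, and the true slope of 0.3 is an arbitrary choice.

```r
library(tidyverse)

# Toy nested data shaped like step.2 (simulated, not the real data)
set.seed(1)
toy <- tibble(
  var           = rep(c("A", "B"), each = 30),
  value         = runif(60, 0, 10),
  Concentration = exp(0.3 * value + rnorm(60, sd = 0.2))
) %>%
  nest(obs.data = c(value, Concentration))

fits <- toy %>%
  mutate(
    # One log-scale linear model per nested data set
    model = map(obs.data, ~ lm(log(Concentration) ~ ., data = .x)),
    # Extract the slope for `value` from each fitted model
    slope = map_dbl(model, ~ coef(.x)[["value"]])
  )

fits %>% select(var, slope)
```

Both slopes should land close to the simulated value of 0.3, one per row of the nested data frame.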
Step 4: Set Up New Data To Predict On
In order to plot a trend line, I need a grid of new data. The library modelr has some great utilities for this sort of thing. The data_grid function does just what I want.
library(modelr)
step.4 = step.3 %>%
  mutate(
    newdata = map(obs.data, ~ data_grid(.x, value = seq_range(value, 20)))
  )
Now I have another new column called newdata, which has a dataframe of evenly spaced observations for each covariate. All that is left to do is to predict on this data using the model we fit in step 3.
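To see what seq_range and data_grid produce on their own, here is a small sketch on a made-up three-row data frame (toy is a hypothetical name, and the numbers are arbitrary).

```r
library(tibble)
library(modelr)

toy <- tibble(value = c(1, 3, 10), Concentration = c(2, 4, 8))

# seq_range() spans the min to the max of a vector in n evenly spaced steps
seq_range(toy$value, 4)
# 1 4 7 10

# data_grid() evaluates the expression inside the data frame and returns
# a tibble holding only the requested column
grid <- data_grid(toy, value = seq_range(value, 4))
grid
```

The grid keeps the column name value, so it can be handed straight to predict for a model that was fit on value.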
Step 5: Predict & Transform
Remember, I took the log of concentration, so I will have to exponentiate my predictions. I have to pass both my model and the new data into the predict function, so that means I have to use map2.
step.5 = step.4 %>%
  mutate(
    preds = map2(model, newdata, ~ exp(predict(.x, newdata = .y))) # Don't forget the exp!
  )
I’ll be using ggplot2 to make the plots, so it is best if I have the x’s and the y’s for the trend line in the same data frame.
step.5 = step.5 %>%
  mutate(
    pred.data = map2(newdata, preds, ~ .x %>% bind_cols(pred = .y))
  )
Now I have another column, pred.data, which has the data I’ll need for drawing the trend line.
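If map2 is new to you, the key idea is that it walks two lists in parallel, pairing the first model with the first grid, the second with the second, and so on. A toy sketch with two simulated models (d1, d2, models, and grids are names I made up for this example):

```r
library(purrr)

# Two toy data sets and one log-scale model per data set (simulated)
set.seed(1)
d1 <- data.frame(value = 1:20,
                 Concentration = exp(0.1 * (1:20) + rnorm(20, sd = 0.1)))
d2 <- data.frame(value = 1:20,
                 Concentration = exp(-0.1 * (1:20) + rnorm(20, sd = 0.1)))
models <- map(list(d1, d2), ~ lm(log(Concentration) ~ value, data = .x))

# Grids of different sizes make the pairing easy to see
grids <- list(data.frame(value = c(5, 10)),
              data.frame(value = c(5, 10, 15)))

# map2() calls the function with models[[i]] as .x and grids[[i]] as .y
preds <- map2(models, grids, ~ exp(predict(.x, newdata = .y)))

lengths(preds)  # one prediction per grid row: 2 and 3
```

Inside mutate the same thing happens row by row, since model and newdata are list columns of equal length.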
Step 6: Make a Function For the Plots
I’ve written a nice helper function to help me draw the plots I want. It will take in as arguments 2 data sets, plotting a scatter plot for 1 data set, and a line plot for the other. Here is the function:
make.plot <- function(pred.data, obs.data, vars){
  plot = ggplot() +
    geom_point(data = obs.data, aes(value, Concentration), shape = 21, fill = 'gray') +
    geom_line(data = pred.data, aes(value, pred), color = 'red') +
    labs(x = vars) +
    theme_classic()
  return(plot)
}
Step 7: Reduce!
I can use pmap to pass in a named list to my make.plot function and then use patchwork and reduce to combine all the plots into a single figure!
library(patchwork)
# The list names must match the argument names of make.plot
arguments = list(pred.data = step.5$pred.data,
                 obs.data = step.5$obs.data,
                 vars = c('Var1','Var2','Var3','Var4'))
plots <- pmap(arguments, make.plot)
final = reduce(plots, `+`)
ggsave('purrrplot.png', plot = final, dpi = 800)
Here is the final product! Another for loop cleverly averted.
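For anyone who wants to try the last step in isolation, here is a toy sketch of how pmap matches a named list to function arguments, with patchwork's wrap_plots as a possible alternative to reduce when you want to control the layout. The function toy.plot and the data are invented for this example.

```r
library(ggplot2)
library(purrr)
library(patchwork)

# Hypothetical plotting helper with two named arguments
toy.plot <- function(df, label) {
  ggplot(df, aes(x, y)) +
    geom_point() +
    labs(title = label)
}

# pmap() matches list names to argument names and makes one call per position
args <- list(
  df    = list(data.frame(x = 1:5, y = 1:5), data.frame(x = 1:5, y = 5:1)),
  label = c("Up", "Down")
)
plots <- pmap(args, toy.plot)

# reduce(plots, `+`) computes plots[[1]] + plots[[2]] + ...;
# wrap_plots() takes the list directly and accepts layout arguments
final <- wrap_plots(plots, ncol = 2)
final
```

Either route gives a single patchwork object that can be passed to ggsave.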
