From the Magical Realm of Statistics: Generalized Additive Models with Integrated Smoothness Estimation

By Chris Rowe on October 4, 2020

Generalized additive models (GAMs) are extensions of generalized linear models (GLMs) in which the link-transformed mean of the response variable is modeled as a sum of smooth functions of the predictor variables. The key distinction is that these smooth functions are not specified in advance; they are unknown and are estimated from the data.
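In symbols (notation assumed here for illustration: \(g\) is the link function, \(\mu = E[Y]\), and the \(f_j\) are the unknown smooth functions), a GAM with \(p\) predictors has roughly this form:

\[ g(\mu) = \beta_0 + f_1(x_1) + f_2(x_2) + \dots + f_p(x_p) \]

A GLM is the special case in which each \(f_j(x_j)\) is forced to be the straight line \(\beta_j x_j\).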

I had previously encountered GAMs a handful of times, applied them once or twice to check for non-linearity in the relationships between variables, but never really dove into trying to understand them. Honestly, after spending some time reading about them over the last few days, I remain fuzzy on some of the technical details (the underlying math gets pretty heady), but what I have learned is that they are extremely cool and powerful. My main references for this post were The Elements of Statistical Learning (ESL; the authors of which originally developed GAMs), this article by Simon Wood (who is a major contributor in this space), and the documentation for the mgcv package, also created by Simon Wood.

Let’s check out some simple examples of what GAMs can do!

Simple Curve Fitting

A nice place to start when playing around with GAMs is trying to approximate functions of a single variable. With just two dimensions, we can visualize everything! Let’s start with a nice, curvy function that they use in ESL:

\[ \begin{aligned} Y &= f(X) + \epsilon \\ \\ f(X) &= \frac{\sin(12\times(X+0.2))}{X+0.2} \\ \\ \textrm{where } X &\sim \textrm{Unif}(0,1) \textrm{ and } \epsilon \sim N(0,1) \end{aligned} \]

# simulate n noisy observations of the true curve
n <- 1000
x <- runif(n)
y_true <- sin(12*(x+0.2))/(x+0.2)
y_obs <- y_true + rnorm(n)  # add N(0,1) noise

Before we dive into fitting some GAMs with the mgcv package, I think it would be helpful to give a brief overview of how the model fitting process works. Again, much of the mechanics are at the frontiers of my current intellectual reach, but I’ll try my best. Hopefully I don’t mischaracterize anything.

The general idea is that the original predictor \(X\) is expanded using a specified number of functions of \(X\) (called basis functions). This is similar to including splines in a regression model, except that knots are not explicitly specified. We want a smooth relationship between \(X\) and \(Y\), and that smoothness is obtained by summing up several different basis functions of \(X\). The default approach in mgcv uses something called thin plate regression splines, which start with knots at the location of each data point \(x_i\), then employ some matrix wizardry (involving an eigendecomposition) to reduce the dimension of the original basis functions that corresponded to having knots at every data point. We specify the final target dimension (the default is 10: intercept + 9 other basis functions for a single predictor \(X\)).

The parameters corresponding to this reduced number of basis functions are estimated using maximum likelihood with an added term that penalizes the curvature or “wiggliness” of the overall fit using some calculus wizardry (essentially penalizing the integrated squared second derivative, which measures the curvature of a function). As with other penalized methods in regression (e.g., lasso or ridge regression), the severity of the penalty is controlled by a hyperparameter, typically called \(\lambda\) in the literature. Fitting via this penalized maximum likelihood must therefore strike a balance between model fit and the curvature/wiggliness of the fit; this is called regularization. In effect, the penalty shrinks the coefficients corresponding to each of the basis functions to achieve an appropriate degree of smoothness. A neat feature of mgcv is that \(\lambda\) is selected automatically, using either generalized cross-validation or other methods that I don’t as readily understand (e.g., REML).
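As a rough sketch of that trade-off for a Gaussian response (the notation is assumed here for illustration and is not copied from the mgcv documentation), the fitted curve minimizes a sum of squares plus a curvature penalty:

\[ \hat{f} = \underset{f}{\arg\min} \; \sum_{i=1}^{n} \big(y_i - f(x_i)\big)^2 + \lambda \int f''(x)^2 \, dx \]

If we write \(f(x) = \sum_{j} \beta_j b_j(x)\) for basis functions \(b_j\), the penalty becomes a quadratic form in the coefficients:

\[ \int f''(x)^2 \, dx = \beta^\top S \beta, \qquad S_{jk} = \int b_j''(x)\, b_k''(x) \, dx \]

Setting \(\lambda = 0\) gives an unpenalized regression on the basis functions, while \(\lambda \to \infty\) forces \(f'' = 0\), which is a straight line. For non-Gaussian responses, the sum of squares is replaced by a negative log-likelihood term.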

Okay, let’s fit a GAM to the data generated above using all the defaults, which involves thin plate regression splines with a basis dimension of 10 and \(\lambda\) selected using generalized cross-validation. Although it’s not necessary, I’ve specified these as arguments just to be explicit. We will then add the predicted values from our GAM to our plot above to see how good a job it does of approximating the true function.

library(mgcv)
mod_gam <- gam(y_obs~s(x, k=10, bs="tp"), method="GCV.Cp")
y_hat <- predict(mod_gam)
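One way to put a number on how close the fit is (a minimal sketch, not part of the original analysis): re-run the same simulation with an arbitrary seed, then compare the fitted values with the noise-free curve. The names `rmse_fit` and `edf_smooth` are introduced here purely for illustration.

```r
library(mgcv)

set.seed(1)  # arbitrary seed, assumed here only for reproducibility

# Same simulation as above
n <- 1000
x <- runif(n)
y_true <- sin(12 * (x + 0.2)) / (x + 0.2)
y_obs <- y_true + rnorm(n)

# Same fit as above
mod_gam <- gam(y_obs ~ s(x, k = 10, bs = "tp"), method = "GCV.Cp")
y_hat <- as.numeric(predict(mod_gam))

# Root mean squared error of the fitted curve against the true curve
rmse_fit <- sqrt(mean((y_hat - y_true)^2))

# Effective degrees of freedom of the smooth after penalization
# (total edf minus 1 for the unpenalized intercept)
edf_smooth <- sum(mod_gam$edf) - 1

print(c(rmse = rmse_fit, edf = edf_smooth))

# Overlay the true curve and the fit on the noisy data (interactive sessions only)
if (interactive()) {
  ord <- order(x)
  plot(x, y_obs, col = "grey70", pch = 16, cex = 0.5)
  lines(x[ord], y_true[ord], lwd = 2)
  lines(x[ord], y_hat[ord], lwd = 2, lty = 2, col = "red")
}
```

The effective degrees of freedom can be at most 9 here, since the smooth has 9 coefficients; a value pressed right up against that ceiling is a hint that a larger `k` might be worth trying.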

Wow! That’s a pretty close fit! Before familiarizing myself with GAMs, I probably would’ve tinkered around with polynomial terms or splines, picking the number and location of knots in a somewhat ad-hoc fashion until the fit looked reasonable. These automatic fitting algorithms remove the need for all of those ad-hoc decisions. In more complicated scenarios, I’d probably probe some fit diagnostics and make sure I’m not under- or over-parameterizing the model for my needs.

The documentation warns that the default choice of dimension 10 (i.e., the \(k\) argument) is arbitrary and that you will want to confirm it is appropriate for the problem at hand. The guidance I’ve seen suggests that what matters most is selecting a dimension that is at least large enough for the problem; the penalty and the resulting shrinkage of the coefficients will handle the rest.
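A minimal sketch of how one might check the basis dimension, assuming the same simulated curve as above: fit the model with `k = 10` and again with a larger value (20 is an arbitrary choice here), then look at mgcv’s `k.check()` table and the AIC of each fit.

```r
library(mgcv)

set.seed(2)  # arbitrary seed, assumed here only for reproducibility

n <- 1000
x <- runif(n)
y_obs <- sin(12 * (x + 0.2)) / (x + 0.2) + rnorm(n)

# Default basis dimension versus a deliberately larger one
mod_k10 <- gam(y_obs ~ s(x, k = 10, bs = "tp"), method = "GCV.Cp")
mod_k20 <- gam(y_obs ~ s(x, k = 20, bs = "tp"), method = "GCV.Cp")

# k.check() reports, per smooth: k', edf, k-index and a p-value
k_tab_10 <- k.check(mod_k10)
k_tab_20 <- k.check(mod_k20)
print(k_tab_10)
print(k_tab_20)

# Compare the two fits on AIC
aic_compare <- AIC(mod_k10, mod_k20)
print(aic_compare)
```

Roughly speaking, an edf close to k' together with a low p-value suggests the basis may be too small. If doubling `k` barely changes the edf or the AIC, the smaller basis was probably adequate.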

I personally think this is very cool and look forward to using these models in real applications. However, I really don’t do a lot of this type of curve fitting in my work. Let’s see how these can apply to a slightly different use case.

Covariate Adjustment in Regression Models

In many fields, it is common that one wants to understand the relationship between a single independent variable and a dependent variable, conditional on other covariates. In epidemiology for example, this arises when we want to investigate the relationship between some treatment/exposure and an outcome, but there exist confounding variables that are correlated with both the treatment and the outcome. In a regression modeling context, estimating the association between the treatment and the outcome without consideration of confounding variables might not be very informative or useful. This is because any relationship we observe between the two variables might be driven by their mutual relationships with the confounding variable(s).

To explore the value of GAMs in this setting, we’re going to use the following data generating process:

\[ X \sim \textrm{Beta}(2,2) \] \[ Pr(T=1|X=x) = x \]

\[ g(Pr(Y=1|T=t, X=x)) = (\log(2)\times t) + \frac{\sin(12\times(x+0.2))}{x+0.2} \]

\[ \textrm{where } g(w) =\log\bigg(\frac{w}{1-w}\bigg) \]

In words, we have a continuous beta distributed covariate \(X\); a binary treatment variable \(T\) that depends on \(X\) such that the probability of receiving treatment equals \(x\) and therefore increases linearly in \(X\); and a binary outcome variable \(Y\) that depends on both \(X\) and \(T\) such that the log odds of experiencing the outcome increases linearly in \(T\) but varies non-linearly with \(X\).

We see that, conditional on \(X\), those who are treated have log(2) = 0.69 greater log odds of experiencing the outcome. This is the parameter that we want to estimate.
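To spell that out as a small worked step from the outcome model above: the non-linear term in \(x\) cancels when we compare treated and untreated at the same \(x\), leaving

\[ g(Pr(Y=1|T=1, X=x)) - g(Pr(Y=1|T=0, X=x)) = \log(2) \approx 0.69 \]

so the conditional odds ratio comparing treated to untreated is \(e^{\log(2)} = 2\) at every value of \(x\).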

In an applied setting, we would not know any of this. We would just have a binary treatment variable, a binary outcome variable, and a continuous covariate. There are many ways one could model the outcome as a function of treatment and covariates, but we’re going to assume we had reason to use a logistic regression model.

Below, we are going to try to fit two different models:

  • A GAM in which we allow \(Y\) to depend on \(T\) and a smooth function of \(X\).
  • A standard GLM in which we allow \(Y\) to depend linearly on both \(T\) and \(X\).

I’d venture to say that it’s not atypical for folks with less statistical training to just assume a linear relationship between covariates and an outcome in this sort of context (honestly, I’ve done this earlier in my career). Instead of just fitting each model once, we are actually going to randomly generate the data from the above process 1000 times so that we can assess each estimator’s bias, variance, and 95% confidence interval coverage.

Ultimately we are exploring a few things: (1) can the GAM recover the complex non-linear relationship between \(X\) and \(Y\) so as to obtain an unbiased estimate of the parameter of interest with valid inference?; and (2) what is the cost (in bias and confidence interval coverage) of assuming a linear relationship between \(X\) and \(Y\) in this case?

# initialize parameters and matrices for holding results
n <- 1000
true_effect <- log(2)
n_iter <- 1000
est <- matrix(rep(NA, n_iter*2), ncol=2)
ci_coverage <- matrix(rep(NA, n_iter*2), ncol=2)
for(i in 1:n_iter){

  # generate data
  x <- rbeta(n, 2, 2)
  t <- rbinom(n, 1, prob=x)
  y <- rbinom(n, 1, prob=plogis(true_effect*t + sin(12*(x+0.2))/(x+0.2)))
  
  # fit models
  mod_gam <- gam(y ~ t + s(x), family=binomial)
  mod_glm <- glm(y ~ t + x, family=binomial)
  
  # save estimates
  est[i,1] <- coef(mod_gam)[2]
  est[i,2] <- coef(mod_glm)[2]
 
  # assess whether confidence interval contains true effect
  lb_gam <- coef(mod_gam)[2] - sqrt(summary(mod_gam)$cov.unscaled[2,2])*qnorm(0.975)
  ub_gam <- coef(mod_gam)[2] + sqrt(summary(mod_gam)$cov.unscaled[2,2])*qnorm(0.975)
  ci_coverage[i, 1] <- ifelse(true_effect >= lb_gam & true_effect <= ub_gam, 1, 0)

  lb_glm <- coef(mod_glm)[2] - sqrt(summary(mod_glm)$cov.unscaled[2,2])*qnorm(0.975)
  ub_glm <- coef(mod_glm)[2] + sqrt(summary(mod_glm)$cov.unscaled[2,2])*qnorm(0.975)
  ci_coverage[i, 2] <- ifelse(true_effect >= lb_glm & true_effect <= ub_glm, 1, 0)  

}

Let’s check out the distribution of the estimated log odds difference from the 1000 iterations for each model. The GAM on the left appears unbiased, with a mean estimate of 0.69, which is spot on the true log odds ratio of log(2) = 0.69! However, the GLM on the right is biased, with a mean estimate of 0.53. Thus, it appears that the GAM was able to recover the complex relationship between \(X\) and \(Y\), allowing for an unbiased estimate of the treatment parameter. On the other hand, incorrectly specifying a linear relationship between \(X\) and \(Y\) biased our estimate of the main parameter of interest.

However, we can also tell from the plots that the GAM estimator has a higher variance than the GLM estimator (0.025 vs. 0.02), which makes sense because we are estimating 11 parameters in the GAM (intercept + treatment + 9 basis functions for the covariate) vs. 3 parameters in the GLM (intercept + treatment + covariate).

What about confidence interval coverage? It turns out that the GAM had excellent 95% confidence interval coverage, with intervals that included the true parameter 96.1% of the time! The incorrectly specified GLM’s confidence intervals, however, only included the true parameter 79.1% of the time.
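For reference, here is a compact sketch of how summaries like these (mean estimate, bias, variance, and coverage) can be computed. It wraps the same data-generating code as the loop above in a hypothetical helper `sim_once()`, renames the treatment to `trt`, uses `vcov()` for the standard errors, and runs only 20 iterations to keep it quick. All of those choices are assumptions of this sketch, so its numbers will be much noisier than those quoted above.

```r
library(mgcv)

set.seed(3)  # arbitrary seed, assumed here only for reproducibility
true_effect <- log(2)

# Simulate one dataset, fit both models, and return the treatment estimate
# plus an indicator for whether the 95% Wald interval covers the truth
sim_once <- function(n = 1000) {
  x <- rbeta(n, 2, 2)
  trt <- rbinom(n, 1, prob = x)
  y <- rbinom(n, 1, prob = plogis(true_effect * trt + sin(12 * (x + 0.2)) / (x + 0.2)))
  dat <- data.frame(y = y, trt = trt, x = x)

  fits <- list(
    gam = gam(y ~ trt + s(x), family = binomial, data = dat),
    glm = glm(y ~ trt + x, family = binomial, data = dat)
  )

  sapply(fits, function(fit) {
    j <- match("trt", names(coef(fit)))   # position of the treatment coefficient
    b <- coef(fit)[[j]]
    se <- sqrt(diag(vcov(fit))[j])
    c(est = b, covered = as.numeric(abs(b - true_effect) <= qnorm(0.975) * se))
  })
}

n_iter <- 20  # far fewer than the 1000 iterations used above
res <- replicate(n_iter, sim_once())  # 2 (est, covered) x 2 (gam, glm) x n_iter

est <- t(res["est", , ])          # n_iter x 2 matrix of estimates
covered <- t(res["covered", , ])  # n_iter x 2 matrix of 0/1 coverage flags

summary_tab <- rbind(
  mean_estimate = colMeans(est),
  bias          = colMeans(est) - true_effect,
  variance      = apply(est, 2, var),
  ci_coverage   = colMeans(covered)
)
print(round(summary_tab, 3))
```

Raising `n_iter` back to 1000 should give summaries comparable to those reported here, at the cost of a longer run time.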

Thus, using GAMs can help us model complex non-linear relationships between confounding variables and our outcome, resulting in unbiased estimates and valid confidence intervals for the primary parameter of interest.

Very cool!

A Note of Caution

The GAMs explored here are just one of many tools for modeling complex relationships between variables. I could imagine many scenarios in which the methods explored here would be a bit overkill, and relationships between variables could be adequately modeled with substantially fewer parameters (e.g., using polynomials or splines). This was meant as a simple demonstration of the methods, but if I were applying these in a real project I would have spent a lot more time on model checking and comparisons with simpler methods.

Also, although the inference appeared to be valid in this toy example, I’m still trying to understand for myself whether the fact that we are employing coefficient shrinkage has any implications for our estimates or inference. It’s well known that shrinkage adds a bit of bias in order to reduce variance, such that traditional frequentist inference for other shrinkage estimators (e.g., lasso) is invalid. This may not be relevant here, but it’s something I intend on further investigating.

Also, mgcv actually employs a Bayesian approach to variance estimation, which is new to me. I’m used to employing typical frequentist variance estimation with typical frequentist confidence intervals, so I’m not even exactly sure what it means to be using Bayesian inference approaches that have good frequentist properties (as those in mgcv appear to have). Either way, I look forward to learning more about it!
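For anyone who wants to poke at the difference, a fitted `gam` object stores both covariance matrices: `Vp` (the Bayesian posterior covariance, which `summary()` uses by default) and `Ve` (the frequentist covariance of the estimators). A minimal sketch on one simulated dataset, with an arbitrary seed and the treatment renamed to `trt` for illustration:

```r
library(mgcv)

set.seed(4)  # arbitrary seed, assumed here only for reproducibility

# One dataset from the same generating code as the simulation above
n <- 1000
x <- rbeta(n, 2, 2)
trt <- rbinom(n, 1, prob = x)
y <- rbinom(n, 1, prob = plogis(log(2) * trt + sin(12 * (x + 0.2)) / (x + 0.2)))

mod_gam <- gam(y ~ trt + s(x), family = binomial)

# Standard errors under each covariance matrix
se_bayes <- sqrt(diag(mod_gam$Vp))  # Bayesian posterior covariance
se_freq  <- sqrt(diag(mod_gam$Ve))  # frequentist covariance of the estimators

# Compare the intercept and treatment coefficient (first two coefficients)
se_compare <- cbind(bayes = se_bayes, freq = se_freq)[1:2, ]
print(round(se_compare, 4))
```

This only shows where the two sets of standard errors live; it doesn’t settle which one is more appropriate for a given inferential goal.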