Cross-validation for NMF rank determination

Four methods for cross-validation of non-negative matrix factorizations

Rank is the most important hyperparameter in NMF. Finding that “sweet spot” rank can make the difference between learning a useful model that captures meaningful signal (but not noise) and learning a garbage model that misses good signal or fits useless noise.

Alex Williams has posted a great introduction to cross-validation for NMF on his blog. His review of the first two methods is particularly intuitive; however, the third method he discusses is both theoretically questionable and poor in practice.

There are three “unsupervised” cross-validation methods for NMF that I have found useful:

  • Bi-cross-validation, proposed by Perry and explained simply by Williams. The “Bi-” in “bi-cross-validation” means that the model is trained on a block of randomly selected samples and features and evaluated on a non-intersecting block of samples and features, so no sample or feature in the test set appears in the training set. If the test and training sets share samples or features, NMF gets to “cheat” during training by directly inferring patterns of regulation, which is why basic subsample cross-validation does not work for NMF.
  • Imputation, described nicely by Lin and also reviewed in this StackExchange post by amoeba. Here, a small fraction of values (e.g., 5%) is “masked” and treated as missing during factorization, and the mean squared error of the imputed values is computed after model training (see the sketch after this list).
  • Robustness is simply the cosine similarity of matched factors in independent models trained on non-overlapping sample sets. The premise is that capturing noise results in low similarity, while efficiently capturing signal results in high similarity. Furthermore, approximations that are too low-rank will not classify signals in the same manner, leading to poor factor matching.
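
To make the imputation idea concrete, here is a minimal sketch in base R (not RcppML’s implementation; the function mask_mse and its arguments are hypothetical). It masks a fraction of entries, fits NMF with weighted multiplicative updates that ignore the masked entries, and scores the held-out entries by mean squared error:

mask_mse <- function(A, k, frac = 0.05, maxit = 100, seed = 1) {
  set.seed(seed)
  M <- matrix(runif(length(A)) > frac, nrow(A), ncol(A)) # TRUE = observed
  W <- matrix(runif(nrow(A) * k), nrow(A), k)
  H <- matrix(runif(k * ncol(A)), k, ncol(A))
  Am <- A * M # masked entries are zeroed and ignored by the updates
  for (i in seq_len(maxit)) { # weighted Lee-Seung multiplicative updates
    H <- H * crossprod(W, Am) / (crossprod(W, (W %*% H) * M) + 1e-9)
    W <- W * tcrossprod(Am, H) / (tcrossprod((W %*% H) * M, H) + 1e-9)
  }
  mean(((W %*% H) - A)[!M]^2) # MSE on the masked (held-out) entries
}

# toy usage: the error curve should flatten near the true rank (5 here)
A_toy <- tcrossprod(matrix(runif(1000), 200), matrix(runif(1000), 200))
sapply(1:8, function(k) mask_mse(A_toy, k))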

Takeaways

  • The predict method (bi-cross-validation) is useful for well-conditioned signal.
  • The robust method (similarity of independent factorizations) is generally the most informative for noisy data possibly suffering from signal dropout.
  • The imputation method is the slowest of the three, but generally the most sensitive.

Install RcppML

Install the development version of RcppML:

# install the development version of RcppML from GitHub
devtools::install_github("zdebruine/RcppML")

library(RcppML)  # nmf, crossValidate, simulateNMF, and example datasets
library(ggplot2)
library(cowplot) # plot_grid, get_legend
library(umap)
library(irlba)   # fast truncated SVD

Simulated data

Simulated data is useful for demonstrating how methods hold up under adversarial perturbations such as noise or dropout.

We will first explore cross-validation using two simulated datasets generated with simulateNMF:

  1. data_clean will have no noise or signal dropout.
  2. data_dirty contains the same signal as data_clean, but with a good amount of noise and dropout.

data_clean <- simulateNMF(nrow = 200, ncol = 200, k = 5, noise = 0, dropout = 0, seed = 123)
data_dirty <- simulateNMF(nrow = 200, ncol = 200, k = 5, noise = 0.5, dropout = 0.5, seed = 123)

Notice that data_clean contains only 5 non-zero singular values, while data_dirty does not.
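
One quick way to verify this is to inspect the leading singular values directly. A sketch using irlba (loaded above), assuming simulateNMF returns matrix-like objects that irlba accepts; irlba may warn when singular values beyond the true rank are numerically zero:

s_clean <- irlba(data_clean, nv = 10)$d
s_dirty <- irlba(data_dirty, nv = 10)$d
plot(s_dirty, type = "b", col = "red", xlab = "component", ylab = "singular value")
lines(s_clean, type = "b", col = "blue")
legend("topright", c("data_dirty", "data_clean"), col = c("red", "blue"), lty = 1)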

We can use RcppML::crossValidate to determine the rank of each dataset. The default method uses “bi-cross-validation”. See ?crossValidate for details.

cv_clean <- crossValidate(data_clean, k = 1:10, method = "predict", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "predict", reps = 3, seed = 123)
plot_grid(
  plot(cv_clean) + ggtitle("bi-cross-validation on\nclean dataset"),
  plot(cv_dirty) + ggtitle("bi-cross-validation on\ndirty dataset"), nrow = 1)

crossValidate also supports a second method, which compares the robustness of two factorizations trained on independent sample subsets.

cv_clean <- crossValidate(data_clean, k = 1:10, method = "robust", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "robust", reps = 3, seed = 123)
plot_grid(
  plot(cv_clean) + ggtitle("robust cross-validation on\nclean dataset"),
  plot(cv_dirty) + ggtitle("robust cross-validation on\ndirty dataset"), nrow = 1)

This second method does better on ill-conditioned data because it measures agreement between independent factorizations rather than reconstruction error on held-out data.
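
Here is a minimal sketch of what “robustness” measures, assuming nmf() returns a model whose w matrix holds feature loadings (the accessor may be m$w or m@w depending on your RcppML version), and using greedy factor matching in place of proper bipartite matching:

set.seed(123)
idx <- sample(ncol(data_dirty), ncol(data_dirty) / 2)
m1 <- nmf(data_dirty[, idx], k = 5, seed = 123)  # model on one half of the samples
m2 <- nmf(data_dirty[, -idx], k = 5, seed = 123) # model on the other half
w1 <- apply(m1$w, 2, function(x) x / sqrt(sum(x^2))) # unit-normalize each factor
w2 <- apply(m2$w, 2, function(x) x / sqrt(sum(x^2)))
sim <- crossprod(w1, w2) # cosine similarity between all factor pairs
mean(apply(sim, 1, max)) # greedily match factors; values near 1 = robust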

Finally, we can use the impute method:

cv_clean <- crossValidate(data_clean, k = 1:10, method = "impute", reps = 3, seed = 123)
cv_dirty <- crossValidate(data_dirty, k = 1:10, method = "impute", reps = 3, seed = 123)
plot_grid(
  plot(cv_clean) + ggtitle("impute cross-validation on\nclean dataset") + scale_y_continuous(trans = "log10"),
  plot(cv_dirty) + ggtitle("impute cross-validation on\ndirty dataset") + scale_y_continuous(trans = "log10"), nrow = 1)

For real datasets, it is important to experiment with all of these cross-validation methods and to explore multi-resolution analysis or other objectives where appropriate.

Let’s take a look at a real dataset:

Finding the rank of the hawaiibirds dataset

data(hawaiibirds)
A <- hawaiibirds$counts
cv_predict <- crossValidate(A, k = 1:20, method = "predict", reps = 3, seed = 123)
cv_robust <- crossValidate(A, k = 1:20, method = "robust", reps = 3, seed = 123)
cv_impute <- crossValidate(A, k = 1:20, method = "impute", reps = 3, seed = 123)
plot_grid(
  plot(cv_predict) + ggtitle("method = 'predict'") + theme(legend.position = "none"),
  plot(cv_robust) + ggtitle("method = 'robust'") + theme(legend.position = "none"),
  plot(cv_impute) + ggtitle("method = 'impute'") + scale_y_continuous(trans = "log10") + theme(legend.position = "none"),
  get_legend(plot(cv_predict)), rel_widths = c(1, 1, 1, 0.4), nrow = 1, labels = "auto")

Finding the rank of the aml dataset

data(aml)
cv_impute <- crossValidate(aml, k = 2:14, method = "impute", reps = 3, seed = 123)
plot(cv_impute) + scale_y_continuous(trans = "log10")

Technical considerations

Runtime is a major consideration for large datasets. Unfortunately, missing value imputation can be very slow.
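
Before committing to a full rank sweep on a large matrix, it can help to time a single cheap run first, for example with base R’s system.time and the same crossValidate arguments used above:

system.time(crossValidate(A, k = 1:3, method = "impute", reps = 1, seed = 123))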

Perturb

Compare missing value imputation when held-out values are perturbed to zeros versus perturbed to random values:

data(hawaiibirds)
data(aml)
data(movielens)
library(Seurat)
pbmc3k
## An object of class Seurat 
## 13714 features across 2700 samples within 1 assay 
## Active assay: RNA (13714 features, 0 variable features)
A <- pbmc3k@assays$RNA@counts

n <- 0.2
method <- "impute"
cv1 <- crossValidate(A, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n) # A = pbmc3k counts
cv2 <- crossValidate(aml, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
cv3 <- crossValidate(movielens$ratings, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
cv4 <- crossValidate(hawaiibirds$counts, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "random", n = n)
plot_grid(
  plot(cv1) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
  plot(cv2) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
  plot(cv3) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
  plot(cv4) + theme(legend.position = "none") + scale_y_continuous(trans = "log10"),
  nrow = 2)
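
The comparison promised at the top of this section also calls for the zero-perturbation runs. Assuming perturb_to accepts "zeros" (implied by the comparison above, but not verified here), the parallel calls would be:

cv1z <- crossValidate(A, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv2z <- crossValidate(aml, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv3z <- crossValidate(movielens$ratings, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)
cv4z <- crossValidate(hawaiibirds$counts, k = 1:15, method = method, reps = 3, seed = 123, perturb_to = "zeros", n = n)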

Zach DeBruine
Assistant Professor of Bioinformatics at Grand Valley State University. Interested in single-cell experiments and dimension reduction. I love simple, fast, and common sense machine learning.