Part of Advances in Neural Information Processing Systems 35 (NeurIPS 2022) Main Conference Track
Kumar K Agrawal, Arnab Kumar Mondal, Arna Ghosh, Blake Richards
Self-Supervised Learning (SSL) with large-scale unlabelled datasets enables learning useful representations for multiple downstream tasks. However, assessing the quality of such representations efficiently poses nontrivial challenges. Existing approaches train linear probes (with frozen features) to evaluate performance on a given task. This is expensive both computationally, since it requires retraining a new prediction head for each downstream task, and statistically, since it requires task-specific labels for multiple tasks. This raises a natural question: how do we efficiently determine the "goodness" of representations learned with SSL across a wide range of potential downstream tasks? In particular, a task-agnostic statistical measure of representation quality that predicts generalization without explicit downstream task evaluation would be highly desirable. In this work, we analyze characteristics of learned representations $\mathbf{f_\theta}$ in well-trained neural networks with canonical architectures and across SSL objectives. We observe that the eigenspectrum of the empirical feature covariance $\mathrm{Cov}(\mathbf{f_\theta})$ can be well approximated by a power-law distribution. We analytically and empirically (using multiple datasets, e.g. CIFAR, STL10, MIT67, ImageNet) demonstrate that the decay coefficient $\alpha$ serves as a measure of representation quality for tasks that are solvable with a linear readout, i.e. there exist well-defined intervals for $\alpha$ where models exhibit excellent downstream generalization. Furthermore, our experiments suggest that key design parameters in SSL algorithms, such as BarlowTwins, implicitly modulate the decay coefficient of the eigenspectrum ($\alpha$). As $\alpha$ depends only on the features themselves, using it as a model-selection criterion during hyperparameter tuning for BarlowTwins enables search with less compute.
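To make the quantity $\alpha$ concrete, the following is a minimal sketch of one way to estimate a power-law decay coefficient from a matrix of frozen features. It assumes the eigenvalues of the empirical covariance, sorted in decreasing order, follow $\lambda_j \propto j^{-\alpha}$, and it fits $\alpha$ by ordinary least squares in log-log space. The abstract does not specify the estimator, so the fitting procedure, the function names `covariance_eigenspectrum` and `fit_power_law_alpha`, and the parameters `rank_min`, `rank_max`, and `eps` are illustrative assumptions rather than the authors' implementation.

```python
import numpy as np


def covariance_eigenspectrum(features):
    """Eigenvalues of the empirical feature covariance, sorted in descending order.

    `features` has shape (n_samples, feature_dim); each row is one representation.
    """
    features = np.asarray(features, dtype=np.float64)
    cov = np.cov(features, rowvar=False)       # (feature_dim, feature_dim)
    eigvals = np.linalg.eigvalsh(cov)[::-1]    # eigvalsh returns ascending order
    return np.clip(eigvals, 0.0, None)         # drop tiny negatives from round-off


def fit_power_law_alpha(eigvals, rank_min=1, rank_max=None, eps=1e-12):
    """Estimate alpha assuming eigvals[j-1] ~ c * j**(-alpha).

    Hypothetical helper: fits a straight line to log(eigenvalue) vs. log(rank)
    over ranks rank_min..rank_max (1-indexed, inclusive) and returns -slope.
    """
    eigvals = np.asarray(eigvals, dtype=np.float64)
    rank_max = len(eigvals) if rank_max is None else rank_max
    ranks = np.arange(rank_min, rank_max + 1)
    lam = eigvals[rank_min - 1:rank_max]
    keep = lam > eps                           # log is undefined for zero eigenvalues
    slope, _intercept = np.polyfit(np.log(ranks[keep]), np.log(lam[keep]), deg=1)
    return -slope


# Synthetic check: Gaussian features whose covariance spectrum is j**(-1) by construction.
rng = np.random.default_rng(0)
true_alpha, dim, n_samples = 1.0, 64, 20000
target_spectrum = np.arange(1, dim + 1, dtype=np.float64) ** -true_alpha
features = rng.standard_normal((n_samples, dim)) * np.sqrt(target_spectrum)

alpha_hat = fit_power_law_alpha(covariance_eigenspectrum(features))
print(f"estimated alpha: {alpha_hat:.3f} (target {true_alpha})")
```

In practice `features` would be the frozen encoder outputs $\mathbf{f_\theta}(x)$ on a held-out set, so the estimate needs no labels and no probe training. The fitted value is sensitive to the rank range, since the tail of an empirical spectrum is distorted by finite-sample noise; restricting `rank_max` is one simple way to limit that effect.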