Asymptotics of feature learning in two-layer networks after one gradient-step

Abstract:

In this manuscript, we investigate how two-layer neural networks learn features from data, and improve over the kernel regime, after being trained with a single gradient descent step. Leveraging a connection established in (Ba et al., 2022) with a non-linear spiked matrix model, together with recent progress on Gaussian universality (Dandi et al., 2023), we provide an exact asymptotic description of the generalization error in the high-dimensional limit where the number of samples n, the width p, and the input dimension d grow at a proportional rate. We characterize exactly how adapting to the data is crucial for the network to efficiently learn non-linear functions in the direction of the gradient -- where at initialization it can only express linear functions in this regime. To our knowledge, our results provide the first tight description of the impact of feature learning on the generalization of two-layer neural networks in the large learning rate regime η = Θ_d(d), beyond perturbative finite-width corrections of the conjugate and neural tangent kernels.
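
As a rough illustration of the setup described in the abstract, the sketch below (not the authors' code) trains a two-layer network by taking a single large gradient step on the first-layer weights and then refitting the readout with ridge regression, comparing against the random-features baseline with frozen first layer. All names, normalizations, the squared loss, and the choice of single-index target and tanh activation are illustrative assumptions; the precise scalings may differ from the paper's conventions.

```python
# Minimal sketch of "one giant gradient step" on the first layer of a two-layer
# network in the proportional regime n, p ~ d, with learning rate eta = Theta_d(d).
# Illustrative assumptions throughout; not the authors' implementation.
import numpy as np

rng = np.random.default_rng(0)

d, p, n, n_test = 500, 750, 1000, 2000   # input dim, width, train/test sizes (n, p ~ d)
eta = 1.0 * d                            # large learning-rate regime eta = Theta_d(d)
lam = 1e-3                               # ridge penalty for the readout (assumption)

# Single-index target with a non-linear link: a linear (kernel-regime) feature map
# cannot capture the non-linearity along w_star.
w_star = rng.standard_normal(d) / np.sqrt(d)
target = lambda X: np.tanh(X @ w_star)

X, Xt = rng.standard_normal((n, d)), rng.standard_normal((n_test, d))
y, yt = target(X), target(Xt)

# Two-layer network f(x) = a^T sigma(W x / sqrt(d)), with a ~ O(1/sqrt(p)) entries.
sigma = np.tanh
W0 = rng.standard_normal((p, d))
a0 = rng.standard_normal(p) / np.sqrt(p)

def preact(X, W):
    return X @ W.T / np.sqrt(d)

# One full-batch gradient step on the first layer (squared loss), readout frozen.
Z = preact(X, W0)
residual = (sigma(Z) @ a0 - y) / n                      # (f(x_i) - y_i) / n
R = residual[:, None] * (1 - sigma(Z) ** 2) * a0[None, :]   # chain rule through sigma
grad_W = R.T @ X / np.sqrt(d)                           # shape (p, d)
W1 = W0 - eta * grad_W                                  # spiked first-layer weights

def ridge_second_layer(W):
    """Fit the readout on fixed features sigma(W x / sqrt(d)); return test MSE."""
    Phi, Phit = sigma(preact(X, W)), sigma(preact(Xt, W))
    a = np.linalg.solve(Phi.T @ Phi + lam * np.eye(p), Phi.T @ y)
    return np.mean((Phit @ a - yt) ** 2)

print("random features (no step):", ridge_second_layer(W0))
print("after one giant step     :", ridge_second_layer(W1))
```

The comparison between the two printed test errors is the object the paper characterizes exactly in the limit where n, p, and d grow proportionally.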

arXiv:2402.04980 [stat.ML]

Last updated on 02/09/2024