Abstract
A central problem in machine learning is the following: how should we train models on data generated by a collection of clients/environments when we know that these models will be deployed in a new, unseen environment? In the few-shot learning setting, two prominent approaches are: (a) develop a modeling framework that is “primed” to adapt, such as Model-Agnostic Meta-Learning (MAML), or (b) develop a common model using federated learning (such as FedAvg), and then fine-tune the model for the deployment environment. We study both approaches in the multi-task linear representation setting. We show that models trained through either approach generalize to new environments because the training dynamics drive the models toward the data representation shared across the clients’ tasks. In both cases, the structure of the bi-level update at each iteration (an inner and outer update with MAML, and a local and global update with FedAvg) holds the key: the inner/local updates exploit the diversity among client data distributions, which induces the outer/global updates to bring the representation closer to the ground truth. In both settings, these are the first results to formally establish representation learning and to derive exponentially fast convergence to the ground-truth representation. Based on joint work with Liam Collins, Hamed Hassani, Aryan Mokhtari, and Sewoong Oh. Papers: https://arxiv.org/abs/2202.03483, https://arxiv.org/abs/2205.13692
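To make the local/global update structure concrete, the following is a minimal NumPy sketch of FedAvg in a multi-task linear representation setting. It is an illustration under stated assumptions, not the authors' code: each client is assumed to minimize the noiseless population squared loss with isotropic Gaussian inputs, which reduces to 0.5 * ||B w - B_star w_star_i||^2; the shared head is initialized at zero; and every name and hyperparameter (`fedavg_linear_representation`, `subspace_distance`, `lr`, `local_steps`, and so on) is hypothetical and chosen only for illustration.

```python
import numpy as np


def subspace_distance(B, B_star):
    """Spectral norm of the part of col(B) lying outside col(B_star).

    B_star is assumed to have orthonormal columns.
    """
    Q, _ = np.linalg.qr(B)  # orthonormal basis for col(B)
    return np.linalg.norm(Q - B_star @ (B_star.T @ Q), 2)


def fedavg_linear_representation(d=10, k=2, n_clients=20, rounds=1000,
                                 local_steps=2, lr=0.1, seed=0):
    """Run FedAvg on a shared representation B (d x k) and head w (k,).

    Assumption: client i's loss is 0.5 * ||B w - B_star w_star_i||^2,
    the population loss for noiseless labels and isotropic inputs.
    Returns the subspace distance to B_star before and after each round.
    """
    rng = np.random.default_rng(seed)
    B_star, _ = np.linalg.qr(rng.standard_normal((d, k)))  # ground truth
    W_star = rng.standard_normal((n_clients, k))           # diverse heads
    B, _ = np.linalg.qr(rng.standard_normal((d, k)))       # random init
    w = np.zeros(k)                                        # zero head init

    dists = [subspace_distance(B, B_star)]
    for _ in range(rounds):
        B_sum, w_sum = np.zeros_like(B), np.zeros_like(w)
        for w_star in W_star:
            # Local update: a few gradient steps on this client's loss.
            B_i, w_i = B.copy(), w.copy()
            target = B_star @ w_star
            for _ in range(local_steps):
                residual = B_i @ w_i - target
                grad_B = np.outer(residual, w_i)
                grad_w = B_i.T @ residual
                B_i -= lr * grad_B
                w_i -= lr * grad_w
            B_sum += B_i
            w_sum += w_i
        # Global update: the server averages the client models.
        B, w = B_sum / n_clients, w_sum / n_clients
        dists.append(subspace_distance(B, B_star))
    return dists


if __name__ == "__main__":
    dists = fedavg_linear_representation()
    print(f"initial distance: {dists[0]:.4f}, final distance: {dists[-1]:.4f}")
```

Running the script prints the distance between the learned and ground-truth column spaces before and after training. At least two local steps are used because, with a zero-initialized head, the first local step leaves the representation unchanged; the representation only moves once the local heads have adapted to their own tasks.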