Federated Learning: Massively distributed learning
Modern mobile devices have access to a vast amount of information, which can be leveraged using machine learning to enhance the user experience. For example, a model can learn how a user types and suggest the next word accordingly, or photos the user takes can be tagged automatically. This data is essentially labelled, since the labels can be inferred from the user’s interactions with it. However, the data is usually private, and the total amount of data across a very large number of users is enormous. For these reasons, it is infeasible to collect all the data from all the users on a central server and train a model there. At the same time, modern mobile devices house powerful hardware. This gives rise to federated learning, which uses this distributed network of devices to learn a shared model.
Traditional machine learning requires data to be centralized. When we move away from this and into federated learning, we face several new challenges. The federated optimization problem has the following properties:
- Data is non-IID: Data is user dependent and the data of a particular user is not representative of the data of all users.
- Data is unbalanced: Different users have different quantities of data.
- Data is massively distributed: Data exists across a very large number of devices and the number of devices is much greater than the number of examples per client.
- Limited Communication: These devices are not always online and communication is expensive.
McMahan et al. introduced the FederatedAveraging algorithm, which addresses the unbalanced and non-IID nature of the data and trains a model using a federation of devices.
The model weights are updated using synchronous SGD. There is a set of K clients, each with a fixed local dataset. In each round, a random fraction C of clients is selected and the current model parameters are sent to them. Each selected client runs E epochs of local training with batch size B and sends its updated weights back to the server, which averages them (weighted by the number of local examples) to obtain the new model parameters.
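The round structure is easiest to see in code. Below is a minimal sketch of FederatedAveraging in NumPy, using a toy linear-regression client for illustration; the function names and the simple squared-error objective are my own, but the round structure (sample a fraction C of clients, run E local epochs with batch size B, average the returned weights by local dataset size) follows the description above.

```python
import numpy as np

def client_update(weights, x, y, epochs=5, batch_size=10, lr=0.1):
    """Run E local epochs of minibatch SGD on one client's data (toy linear model)."""
    w = weights.copy()
    n = len(x)
    for _ in range(epochs):
        order = np.random.permutation(n)
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            pred = x[batch] @ w
            grad = x[batch].T @ (pred - y[batch]) / len(batch)  # squared-error gradient
            w -= lr * grad
    return w

def federated_averaging(clients, dim, rounds=100, C=0.1):
    """clients: list of (x, y) local datasets; returns the shared model weights."""
    w_global = np.zeros(dim)
    m = max(1, int(C * len(clients)))                 # clients selected per round
    for _ in range(rounds):
        chosen = np.random.choice(len(clients), m, replace=False)
        sizes, local_weights = [], []
        for k in chosen:
            x, y = clients[k]
            local_weights.append(client_update(w_global, x, y))
            sizes.append(len(x))
        # New global model: average of the returned models, weighted by dataset size.
        w_global = np.average(local_weights, axis=0, weights=sizes)
    return w_global
```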
The authors run several experiments to evaluate the approach and study the effect of the hyperparameters. The tasks chosen are MNIST digit recognition and language modelling with a dataset built from “The Complete Works of William Shakespeare”. For the image task, the number of clients is 100 whereas for the language modelling task, the number of clients is 1146.
For MNIST digit classification, they train a 2-layer multilayer perceptron and a convolutional neural network (CNN). The models are trained on both IID and non-IID partitions of the data. In the IID setting, the data is shuffled and each client receives 600 examples. In the non-IID setting, the data is first sorted by label and divided into 200 shards of size 300, and each client is assigned 2 shards. This is an extreme non-IID setting, as most clients only have examples of two digits.
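For concreteness, this pathological non-IID partition can be written in a few lines. The sketch below assumes only an array of MNIST training labels; the function name is mine, and the shard counts (200 shards of 300, two per client) follow the description above.

```python
import numpy as np

def pathological_non_iid_split(labels, num_clients=100, shards_per_client=2, seed=0):
    """Sort example indices by label, cut them into equal shards (200 shards of 300
    for MNIST), and assign each client 2 shards, so most clients see only 2 digits."""
    rng = np.random.default_rng(seed)
    num_shards = num_clients * shards_per_client
    shard_size = len(labels) // num_shards
    sorted_idx = np.argsort(labels)                    # indices grouped by label
    shard_order = rng.permutation(num_shards)
    clients = []
    for c in range(num_clients):
        own = shard_order[c * shards_per_client:(c + 1) * shards_per_client]
        idx = np.concatenate(
            [sorted_idx[s * shard_size:(s + 1) * shard_size] for s in own])
        clients.append(idx)
    return clients                                     # one index array per client
```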
For the language modelling task, a client dataset is created for each speaking role with at least two lines in each play. Each client dataset is split into a train set (80%) and a test set (20%). This data is highly unbalanced, as the amount of dialogue varies widely across roles. An IID, balanced version of the dataset is also created. A character-level LSTM is trained to learn the language model.
The authors first experiment with increasing parallelism for the MNIST task. C controls the multi-client parallelism. The following table shows the number of communication rounds required to reach target accuracy for the particular setting of C and B (local batch size). B = ∞ means that the entire local dataset is considered as one batch.
As we can see from the table above, increasing C to 0.1 significantly reduces the number of communication rounds required to reach the target accuracy when B = 10. The improvement is even more pronounced for non-IID data. Based on these results, the authors set C = 0.1 for all experiments, which provides a good balance between convergence rate and computational efficiency.
Next, the authors study the effect of increasing the computation carried out at each client. The parameters B (local batch size) and E (local epochs) control the computation per client: it can be increased by increasing E or by decreasing B. The following table and graph show the effect of varying B and E on the number of communication rounds required to reach the target accuracy. The value u = E·(n/K)/B, where n/K is the expected number of examples per client, is the expected number of updates per client per round (B = ∞ counts the full local dataset as a single batch); it combines E and B and serves as a measure of the amount of computation per client per round.
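As a quick sanity check of this bookkeeping for the MNIST setup above (the specific E, B pairs below are illustrative, not a reproduction of the paper's table):

```python
# u = E * (n / K) / B: expected local updates per client per round.
# MNIST: n = 60,000 training examples split over K = 100 clients.
examples_per_client = 60_000 / 100                 # 600 examples per client
for E, B in [(1, 600), (1, 10), (5, 10), (20, 10)]:
    u = E * examples_per_client / B
    print(f"E={E:2d}, B={B:3d} -> u = {u:.0f}")
# B = 600 plays the role of B = infinity here (one batch per epoch), so u = E.
```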
The table shows that increasing the computation per client leads to faster convergence. This holds for both the CNN and the LSTM, and for both IID and non-IID data. Real-world federated data most closely resembles the non-IID, unbalanced nature of the language modelling task, for which a speedup of 95x is observed. The authors infer that the speedup is so large in this case because some clients have a large amount of data, which makes increased local training more valuable.
The authors conduct further experiments on the CIFAR classification task and on large-scale LSTM models, and the approach performs well on these too. Overall, they demonstrate that federated learning is practical and provide an algorithm that can be used to train a model on a federation of devices.
While this introductory work addresses several aspects of the federated optimization problem, it does not directly tackle the cost of communication. In the paper “Federated Learning: Strategies for Improving Communication Efficiency”, Konecny et al. study different methods of reducing communication costs and their impact.
When the data is as massively distributed as it is in federated learning, computation is no longer the bottleneck; communication is. All the devices in the federation have to send updates to the parameter server, so communication efficiency is of the utmost importance. The authors put forward two ways of reducing communication costs: structured updates and sketched updates. They show that it is possible to reduce communication costs by two orders of magnitude without affecting performance.
First, recall what is communicated in federated learning. The federation of devices learns a shared model whose parameters are matrices W of shape d1 × d2. At each round, W is updated by some update H, computed with a technique such as gradient descent. Each client i computes a local update Hi and sends it to the server, which averages the updates from all clients and applies the result to obtain the new parameters. It is this update H that has to be communicated.
First we will look at structured updates. In this method, the update H is restricted to have a pre-specified structure. There are two variants: “low rank” and “random mask”. In low rank, each local update H (of shape d1 × d2) is restricted to be a rank-k matrix by expressing it as the product of two matrices A (d1 × k) and B (k × d2). A is generated randomly and held constant, while B is learned. A does not need to be communicated: sending the random seed used to generate it is sufficient. Thus only the smaller matrix B needs to be transmitted and the cost is reduced. In the random mask technique, H is restricted to be a sparse matrix following a predefined random sparsity pattern, which can again be communicated using only a seed, so only the non-zero values of the matrix need to be sent.
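A rough sketch of what each structured variant actually transmits is given below. The shapes, the sparsity level and the variable names are illustrative; in the real algorithm B and the masked entries are what the client optimizes, and the snippet only shows the encoding.

```python
import numpy as np

d1, d2, k = 128, 64, 8                 # illustrative update shape and target rank
rng = np.random.default_rng(0)

# Low rank: H = A @ B. A is regenerated from a shared seed on both sides,
# so only B (k * d2 values) plus the seed are sent instead of d1 * d2 values.
seed_A = 42
A = np.random.default_rng(seed_A).standard_normal((d1, k))   # fixed, never sent
B = rng.standard_normal((k, d2)) * 0.01                      # learned by the client
H_low_rank = A @ B
payload_low_rank = (seed_A, B)

# Random mask: H is forced to be sparse with a seed-defined pattern,
# so only the non-zero values (about 6.25% of entries here) plus the seed are sent.
seed_mask = 7
mask = np.random.default_rng(seed_mask).random((d1, d2)) < 0.0625
H_masked = np.where(mask, rng.standard_normal((d1, d2)), 0.0)
payload_random_mask = (seed_mask, H_masked[mask])

print(B.size / (d1 * d2), H_masked[mask].size / (d1 * d2))   # fraction of values sent
```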
The next method is sketched updates. Here the full update H is computed first, then compressed, and the compressed update is sent to the server, which decodes the updates before using them. There are multiple ways of compressing the updates, and the authors study two of them. The first is subsampling: instead of the full update, a random subset of its entries is sent, and the server averages the subsampled updates from the different clients. This is similar to the random mask approach described in the previous section, and the expected value of the averaged update remains unchanged. The second method is probabilistic quantization, in which the values of the update are quantized. Each scalar in the update matrix H is quantized to one bit: if h = vec(H), h_max is the maximum of h and h_min is the minimum of h, the compressed update is given by h̃_j = h_max with probability (h_j − h_min)/(h_max − h_min), and h̃_j = h_min otherwise.
Here j indexes the elements of h, and the quantization is unbiased in expectation. This provides 32x compression compared to 4-byte floats. Scalars can also be quantized to multiple bits, which reduces the compression ratio but loses less information.
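A small sketch of both compression ideas in NumPy follows; the function names are mine, and for readability the subsampled update is returned already rescaled so that its expectation matches the original.

```python
import numpy as np

rng = np.random.default_rng(0)

def subsample(H, fraction, rng):
    """Keep a random fraction of the entries, scaled so the expected value
    of the sparse update equals the original update."""
    keep = rng.random(H.shape) < fraction
    return np.where(keep, H / fraction, 0.0)

def quantize_1bit(H, rng):
    """Probabilistic 1-bit quantization: each entry becomes h_min or h_max with
    probabilities chosen so that its expected value equals the original entry."""
    h_min, h_max = H.min(), H.max()
    p_up = (H - h_min) / (h_max - h_min)          # probability of rounding up to h_max
    quantized = np.where(rng.random(H.shape) < p_up, h_max, h_min)
    return quantized, (h_min, h_max)              # 1 bit per entry plus two floats

H = rng.standard_normal((128, 64))
H_sub = subsample(H, 0.1, rng)
H_q, _ = quantize_1bit(H, rng)
print(H.mean(), H_sub.mean(), H_q.mean())         # all close: both schemes are unbiased
```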
The authors conduct several experiments to evaluate the different methods. Experiments are run on the CIFAR-10 image classification task under various settings, and on Reddit post data for next-word prediction. The Reddit experiments are particularly useful because they closely simulate a federated learning scenario: each post can be attributed to a particular user.
First we will look at the performance of structured updates on the CIFAR task. The upper graph is for “low rank” and the lower graph is for “random mask”. Mode refers to the percentage of the update that is retained: if mode is 100, no size reduction has occurred; if mode is 6.25, the update is 6.25% of its unreduced size.
As we can see from the graphs above, as we reduce the size of the update, the performance of low rank degrades significantly. “Random mask”, however, maintains its performance even when the update is much smaller than the full update. From this we can conclude that random mask performs better than low rank.
Next the authors compare the performance of structured updates with sketched updates. Low rank is not considered as random mask gives better performance. The following graphs show the comparison of structured random mask updates and sketched updates without quantization on the CIFAR data.
As we can see from the graphs, random mask performs better than sketched updates. This is expected, as a significant amount of information is thrown away in sketched updates. However, we also see that with sketched updates the accuracy rises more quickly in the early stages of training than with random mask.
Next the authors study the impact of applying a random rotation before quantization. They find that applying a random rotation to h (the vectorized version of H) before quantizing it leads to better results. The graph below compares sketched updates that combine preprocessing with rotations, quantization and subsampling on the CIFAR data. rotations=I indicates no rotation has been applied; rotations=HD indicates a structured random rotation (a Walsh-Hadamard matrix composed with a random ±1 diagonal matrix) has been applied before quantization.
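The intuition is that a 1-bit quantizer is hurt badly by a few extreme coordinates, and rotating first spreads the values out more evenly. Below is a minimal sketch of an HD-style rotation, assuming SciPy's explicit Hadamard matrix for clarity (a fast Walsh-Hadamard transform would be used in practice, and the vector length must be a power of two); the function names are mine.

```python
import numpy as np
from scipy.linalg import hadamard

def hd_rotate(x, signs):
    """Random sign flip (D) followed by a scaled Walsh-Hadamard matrix (H)."""
    n = len(x)
    Hn = hadamard(n) / np.sqrt(n)        # orthogonal and symmetric
    return Hn @ (signs * x)

def hd_unrotate(y, signs):
    """Inverse rotation, applied by the server after dequantization."""
    n = len(y)
    Hn = hadamard(n) / np.sqrt(n)
    return signs * (Hn @ y)

rng = np.random.default_rng(0)
x = rng.standard_normal(1024)
signs = rng.choice([-1.0, 1.0], size=x.size)   # the random diagonal of D
rotated = hd_rotate(x, signs)                  # quantize in this basis, transmit, ...
recovered = hd_unrotate(rotated, signs)        # ... then rotate back on the server
print(np.allclose(x, recovered))               # True up to floating point error
```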
As we can see from the graphs above, increasing the quantization bits increases the accuracy as less information is lost. However, this results in larger update size. We also see that random rotation with quantization is needed for sketched updates to provide good performance.
Next the authors study the performance of sketched updates for the LSTM trained for next-word prediction on the Reddit data.
We see that in this case as well, preprocessing with a random rotation is very helpful. In this experiment there are 2000 clients and only 50 are sampled in each round. Since so few clients are sampled per round, most of the data is not even touched, and yet the model shows good top-1 accuracy.
Finally, the authors study the effect of the number of clients used in a single round on the convergence speed. The experiment is carried out with subsampling and quantization. The following graph shows the results of this experiment.
The graph shows that, given a large enough number of clients, subsampling only 1% of the elements of each update gives the same performance as subsampling 10% of the elements with fewer clients. This shows that if we have enough clients, we can reduce the communication from each one. This points to a very important and practical tradeoff in federated learning: “one can select more clients in each round while having each of them communicate less (e.g., more aggressive subsampling), and obtain the same accuracy as using fewer clients, but having each of them communicate more”. The former is preferred when many clients are available but each has low bandwidth, as is the case in practical federated learning.
The authors demonstrate that communication costs in federated learning can be reduced by two orders of magnitude while achieving similar performance. They also identify a key tradeoff in federated learning.
The two papers covered in this blog introduce the concept of federated learning and the motivation behind it, and show how it can be made practically feasible by addressing the unique challenges it raises.
References:
- “Communication-Efficient Learning of Deep Networks from Decentralized Data” by McMahan et al.: https://arxiv.org/pdf/1602.05629.pdf
- “Federated Learning: Strategies for Improving Communication Efficiency” by Konecny et al.: https://arxiv.org/pdf/1610.05492.pdf