An Optimization Method for Non-IID Federated Learning Based on Deep Reinforcement Learning

Bibliographic Details
Main Authors: Meng, Xutao; Li, Yong; Lu, Jianchao; Ren, Xianglin
Format: Online Article Text
Language: English
Published: MDPI 2023
Subjects:
Online Access: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC10675381/
https://www.ncbi.nlm.nih.gov/pubmed/38005610
http://dx.doi.org/10.3390/s23229226
Description
Summary: Federated learning (FL) is a distributed machine learning paradigm that enables a large number of clients to collaboratively train models without sharing data. However, when the private datasets across clients are not independent and identically distributed (non-IID), the local training objectives are inconsistent with the global training objective, which can slow the convergence of FL or even prevent it from converging. In this paper, we design a novel FL framework based on deep reinforcement learning (DRL), named FedRLCS. In FedRLCS, we primarily improve the greedy strategy and action space of the double DQN (DDQN) algorithm, enabling the server to select the optimal subset of clients with non-IID data to participate in training, thereby accelerating model convergence and reaching the target accuracy in fewer communication rounds. In simulation experiments, we partition multiple datasets with different strategies to simulate non-IID data on local clients. We adopt four models (LeNet-5, MobileNetV2, ResNet-18, ResNet-34) on four datasets (CIFAR-10, CIFAR-100, NICO, Tiny ImageNet), respectively, and conduct comparative experiments with five state-of-the-art non-IID FL methods. Experimental results show that FedRLCS reduces the number of communication rounds required to reach the same target accuracy by 10–70% without increasing the computation and storage costs for all clients.
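
The summary does not spell out how FedRLCS modifies the greedy strategy or the action space, so the following is only a minimal sketch of the generic building blocks it names: an epsilon-greedy rule that picks a subset of K clients from per-client value estimates, and the standard double DQN target. The function names (select_clients, double_dqn_target), the per-client Q-value layout, and all parameter values are assumptions made for illustration, not the authors' implementation.

```python
import random
from typing import List, Sequence


def select_clients(q_values: Sequence[float], k: int, epsilon: float,
                   rng: random.Random) -> List[int]:
    """Pick k distinct client indices with an epsilon-greedy top-k rule.

    Assumption: q_values[i] is the server's value estimate for client i.
    """
    if not 0 < k <= len(q_values):
        raise ValueError("k must be between 1 and the number of clients")
    if rng.random() < epsilon:
        # Explore: sample k clients uniformly at random.
        return sorted(rng.sample(range(len(q_values)), k))
    # Exploit: take the k clients with the highest estimated value.
    ranked = sorted(range(len(q_values)), key=lambda i: q_values[i], reverse=True)
    return sorted(ranked[:k])


def double_dqn_target(reward: float, gamma: float,
                      next_q_online: Sequence[float],
                      next_q_target: Sequence[float],
                      done: bool) -> float:
    """Standard double DQN target (not a FedRLCS-specific variant).

    The online network chooses the next action; the target network scores it.
    """
    if done:
        return reward
    best = max(range(len(next_q_online)), key=lambda i: next_q_online[i])
    return reward + gamma * next_q_target[best]


if __name__ == "__main__":
    rng = random.Random(0)
    q = [0.1, 0.9, 0.4, 0.7, 0.2]  # hypothetical value estimates for 5 clients
    print(select_clients(q, k=2, epsilon=0.0, rng=rng))  # greedy → [1, 3]
    print(select_clients(q, k=2, epsilon=1.0, rng=rng))  # random subset of size 2
```

With epsilon set to 0 the rule always returns the K highest-valued clients; raising epsilon trades some of that exploitation for random exploration of other client subsets.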