上一篇文章稍微入门了一下Tensorflow Federated框架,但是目前来说,要实现联邦学习(实验)算法用它还是“杀鸡用牛刀”。因此,在一番探索后,我发现朴素Tensorflow也能实现联邦学习算法,甚至还可以手动分开Client端和Server端代码,逻辑更清晰。稍作修改,添加网络传输后甚至可以部署到分布式场景,实现真正意义上的联邦学习(性能估计不会太好hhh)。
这篇文章中,我将分享一种实现联邦学习的方法,它具有以下优点:
不需要读写文件来保存、切换Client模型不需要在每次epoch重新初始化Client变量内存占用尽可能小(参数量仅翻一倍,即Client端+Server端)切换Client只增加了一些赋值操作继续阅读之前,默认大家对联邦学习有一些了解,并达成以下共识:
学习的目标是一个更好的模型,由Server保管,Clients提供更新数据(Data)由Clients保管、使用文章的代码环境、库依赖:
Python 3.7Tensorflow v1.14.xtqdm(一个Python模块)接下来本文会分成Client端、Server端代码设计与实现进行讲解。懒得看讲解的胖友可以直接拉到最后的完整代码章节,共有四个代码文件,运行python Server.py即可以立马体验原汁原味的(单机模拟)联邦学习。
从算法角度,本文实现的是传递更新后的模型参数的实验代码,另一篇文章(https://blog.csdn.net/Mr_Zing/article/details/109496824)实现了传递梯度的实验代码。
明确一下Client端的任务,包含下面三个步骤:
将Server端发来的模型变量加载到模型上用自己的所有数据更新当前模型将更新后的模型变量发回给Server在这些任务下,我们可以设计出Client代码需要具备的一些功能:
创建、训练Tensorflow模型(也就是计算图)加载Server端发过来的模型变量值提取当前模型的变量值,发送给Server维护自己的数据集用于训练其实,仔细一想也就比平时写的tf模型代码多了个加载、提取模型变量。假设Client类已经构建好了模型,那么sess.run()一下每个变量,即可得到模型变量的值了。下面的代码展示了部分Clients类的定义,get_client_vars函数将返回计算图中所有可训练的变量值:
class Clients: def __init__(self, input_shape, num_classes, learning_rate, clients_num): self.graph = tf.Graph() self.sess = tf.Session(graph=self.graph) """ 本函数未完待续... """ def get_client_vars(self): """ Return all of the variables list """ with self.graph.as_default(): client_vars = self.sess.run(tf.trainable_variables()) return client_vars加载Server端发过来的global_vars到模型变量上,核心在于tf.Variable.load()函数,把一个Tensor的值加载到模型变量中,例如:
variable.load(tensor, sess)将tensor(类型为tf.Tensor)的值赋值给variable(类型为tf.Varibale),sess是tf.Session。
如果要把整个模型中的变量值都加载,可以用tf.trainable_variables()获取计算图中的所有可训练变量(一个list),保证它和global_vars的顺序对应后,可以这样实现:
def set_global_vars(self, global_vars): """ Assign all of the variables with global vars """ with self.graph.as_default(): all_vars = tf.trainable_variables() for variable, value in zip(all_vars, global_vars): variable.load(value, self.sess)此外,Clients类还需要进行模型定义和训练。我相信这不是实现联邦的重点,因此在下面的代码中,我将函数体去掉只留下接口定义(完整代码在最后一个章节):
import tensorflow as tf import numpy as np from collections import namedtuple import math # 自定义的模型定义函数 from Model import AlexNet # 自定义的数据集类 from Dataset import Dataset # The definition of fed model # 用namedtuple来储存一个模型,依次为: # X: 输入 # Y: 输出 # DROP_RATE: 顾名思义 # train_op: tf计算图中的训练节点(一般是optimizer.minimize(xxx)) # loss_op: 顾名思义 # loss_op: 顾名思义 FedModel = namedtuple('FedModel', 'X Y DROP_RATE train_op loss_op acc_op') class Clients: def __init__(self, input_shape, num_classes, learning_rate, clients_num): self.graph = tf.Graph() self.sess = tf.Session(graph=self.graph) # Call the create function to build the computational graph of AlexNet # `net` 是一个list,依次包含模型中FedModel需要的计算节点(看上面) net = AlexNet(input_shape, num_classes, learning_rate, self.graph) self.model = FedModel(*net) # initialize 初始化 with self.graph.as_default(): self.sess.run(tf.global_variables_initializer()) # Load Cifar-10 dataset # NOTE: len(self.dataset.train) == clients_num # 加载数据集。对于训练集:`self.dataset.train[56]`可以获取56号client的数据集 # `self.dataset.train[56].next_batch(32)`可以获取56号client的一个batch,大小为32 # 对于测试集,所有client共用一个测试集,因此: # `self.dataset.test.next_batch(1000)`将获取大小为1000的数据集(无随机) self.dataset = Dataset(tf.keras.datasets.cifar10.load_data, split=clients_num) def run_test(self, num): """ Predict the testing set, and report the acc and loss 预测测试集,返回准确率和loss num: number of testing instances """ pass def train_epoch(self, cid, batch_size=32, dropout_rate=0.5): """ Train one client with its own data for one epoch 用`cid`号的client的数据对模型进行训练 cid: Client id """ pass def choose_clients(self, ratio=1.0): """ randomly choose some clients 随机选择`ratio`比例的clients,返回编号(也就是下标) """ client_num = self.get_clients_num() choose_num = math.floor(client_num * ratio) return np.random.permutation(client_num)[:choose_num] def get_clients_num(self): """ 返回clients的数量 """ return len(self.dataset.train)细心的同学可能已经发现了,类名是Clients是复数,表示一堆Clients的集合。但模型self.model只有一个,原因是:不同Clients的模型实际上是一样的,只是数据不同;类成员self.dataset已经对数据进行了划分,需要不同client参与训练时,只需要用Server给的变量值把模型变量覆盖掉,再用下标cid找到该Client的数据进行训练就得了。
当然,这样实现的最重要原因,是避免构建那么多个Client的计算图。咱没那么多显存TAT 概括一下:联邦学习的Clients,只是普通TF训练模型代码上,加上模型变量的值提取、赋值功能。
按照套路,明确一下Server端代码的主要任务:
使用Clients:给一组模型变量给某个Client进行更新,把更新后的变量值拿回来管理全局模型:每一轮更新,收集多个Clients更新后的模型进行归总,成为新一轮的模型简单起见,我们Server端的代码不再抽象成一个类,而是以脚本的形式编写。首先,实例化咱们上面定义的Clients:
from Client import Clients def buildClients(num): learning_rate = 0.0001 num_input = 32 # image shape: 32*32 num_input_channel = 3 # image channel: 3 num_classes = 10 # Cifar-10 total classes (0-9 digits) #create Client and model return Clients(input_shape=[None, num_input, num_input, num_input_channel], num_classes=num_classes, learning_rate=learning_rate, clients_num=num) CLIENT_NUMBER = 100 client = buildClients(CLIENT_NUMBER) global_vars = client.get_client_vars()client变量储存着CLIENT_NUMBER个Clients的模型(实际上只有一个计算图)和数据。global_vars储存着Server端的模型变量值,也就是我们大名鼎鼎的训练目标,目前它只是Client端模型初始化的值。
接下来,对于Server的一个epoch,Server会随机挑选一定比例的Clients参与这轮训练,分别把当前的Server端模型global_vars交给它们进行更新,并分别收集它们更新后的变量。本轮参与训练的Clients都收集后,平均一下这些更新后的变量值,就得到新一轮的Server端模型,然后进行下一个epoch。下面是循环epoch更新的代码,仔细看注释哦:
def run_global_test(client, global_vars, test_num): """ 跑一下测试集,输出ACC和Loss """ client.set_global_vars(global_vars) acc, loss = client.run_test(test_num) print("[epoch {}, {} inst] Testing ACC: {:.4f}, Loss: {:.4f}".format( ep + 1, test_num, acc, loss)) CLIENT_RATIO_PER_ROUND = 0.12 # 每轮挑选clients跑跑看的比例 epoch = 360 # epoch上限 for ep in range(epoch): # We are going to sum up active clients' vars at each epoch # 用来收集Clients端的参数,全部叠加起来(节约内存) client_vars_sum = None # Choose some clients that will train on this epoch # 随机挑选一些Clients进行训练 random_clients = client.choose_clients(CLIENT_RATIO_PER_ROUND) # Train with these clients # 用这些Clients进行训练,收集它们更新后的模型 for client_id in tqdm(random_clients, ascii=True): # Restore global vars to client's model # 将Server端的模型加载到Client模型上 client.set_global_vars(global_vars) # train one client # 训练这个下标的Client client.train_epoch(cid=client_id) # obtain current client's vars # 获取当前Client的模型变量值 current_client_vars = client.get_client_vars() # sum it up # 把各个层的参数叠加起来 if client_vars_sum is None: client_vars_sum = current_client_vars else: for cv, ccv in zip(client_vars_sum, current_client_vars): cv += ccv # obtain the avg vars as global vars # 把叠加后的Client端模型变量 除以 本轮参与训练的Clients数量 # 得到平均模型、作为新一轮的Server端模型参数 global_vars = [] for var in client_vars_sum: global_vars.append(var / len(random_clients)) # run test on 1000 instances # 跑一下测试集、输出一下 run_global_test(client, global_vars, test_num=600)经过那么一些轮的迭代,我们就可以得到Server端的训练好的模型参数global_vars了。虽然它逻辑很简单,但我希望观众老爷们能注意到其中的两个联邦点:Server端代码没有接触到数据;每次参与训练的Clients数量相对于整体来说是很少的。
如果要更换模型,只需要实现新的模型计算图构造函数,替换Client端的AlexNet函数,保证它能返回那一系列的计算节点即可。
如果要实现Non-I.I.D.的数据分布,只需要修改Dataset.py中的数据划分方式。但是,我稍微试验了一下,目前这个模型+训练方式,不能应对极度Non-I.I.D.的情况。也反面证明了,Non-I.I.D.确实是联邦学习的一个难题。
如果要Clients和Server之间传模型梯度,需要把Client端的计算梯度和更新变量分开,中间插入和Server端的交互,交互内容就是梯度。这样说有点抽象,很多同学可能经常用Optimizer.minimize(文档在这),但并不知道它是另外两个函数的组合,分别为:compute_gradients()和apply_gradients()。前者是计算梯度,后者是把梯度按照学习率更新到变量上。把梯度拿到后,交给Server,Server返回一个全局平均后的梯度再更新模型。尝试过是可行的,但是并不能减少传输量,而且单机模拟实现难度大了许多。这一部分请参考另一篇博文:https://blog.csdn.net/Mr_Zing/article/details/109496824。
如果要分布式部署,那就把Clients端代码放在flask等web后端服务下进行部署,Server端通过网络传输与Clients进行通信。需要注意,Server端发起请求的时候,可能因为参数量太大导致一些问题,考虑换个非HTTP协议。
一共有四个代码文件,他们应当放在同一个文件目录下:
Client.py:Client端代码,管理模型、数据Server.py:Server端代码,管理Clients、全局模型Dataset.py:定义数据的组织形式Model.py:定义TF模型的计算图我也将它们传到了Github上,仓库链接:https://github.com/Zing22/tf-fed-demo。下面开始分别贴出它们的完整代码,其中的注释只有我边打码边写的一点点,上文的介绍中补充了更多中文注释。运行方法非常简单:
python Server.py