fedjax core
FedJAX API.
Subpackages
Federated algorithm
- fedjax.FederatedAlgorithm: Container for all federated algorithms.
Federated data
- fedjax.FederatedData: FederatedData interface for providing access to a federated dataset.
- fedjax.SubsetFederatedData: A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.
- fedjax.SQLiteFederatedData: Federated dataset backed by SQLite.
- fedjax.InMemoryFederatedData: A simple wrapper over a concrete fedjax.FederatedData for small in-memory datasets.
- fedjax.FederatedDataBuilder: FederatedDataBuilder interface.
- fedjax.SQLiteFederatedDataBuilder: Builds SQLite files from a Python dictionary containing an arbitrary mapping of client IDs to NumPy examples.
- fedjax.ClientPreprocessor: A chain of preprocessing functions on all examples of a client dataset.
- fedjax.shuffle_repeat_batch_federated_data: Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.
- fedjax.padded_batch_federated_data: Padded batch all client datasets, useful for evaluation on the entire federated dataset.
Client dataset
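- fedjax.ClientDataset: In-memory client dataset backed by NumPy ndarrays.
- fedjax.BatchPreprocessor: A chain of preprocessing functions on batched examples.
- fedjax.buffered_shuffle_batch_client_datasets: Shuffles and batches examples from multiple client datasets.
- fedjax.padded_batch_client_datasets: Batches examples from multiple client datasets.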
For each client
- fedjax.for_each_client: Creates a function which maps over clients.
- fedjax.for_each_client_backend: A context manager for switching to a given ForEachClientBackend in the current thread.
- fedjax.set_for_each_client_backend: Sets the for_each_client backend for the current thread.
Model
- Container class for models.
- Creates Model after applying defaults and Haiku-specific preprocessing.
- Creates Model after applying defaults and stax-specific preprocessing.
- Evaluates model for multiple batches and returns final results.
- A standard gradient function derived from a model and an optional regularizer.
- Convenience function for constructing a per-example loss function from a model.
- Evaluates the average per-example loss over multiple batches.
- Evaluates model for each client dataset, using either global params or per-client params.
- Evaluates average loss for each client dataset, using either global params or per-client params.
- A standard gradient function derived from per-example loss and an optional regularizer.
- class fedjax.FederatedAlgorithm(init, apply)
Container for all federated algorithms.
FederatedAlgorithm defines the required methods that need to be implemented by all custom federated algorithms. Defining federated algorithms in this structure allows implementations to work seamlessly with the convenience methods in the fedjax.training API, like checkpointing.
Example toy implementation:
```python
import jax
import jax.numpy as jnp
from fedjax import ClientDataset, FederatedAlgorithm

# Federated algorithm that just counts the total number of points across
# clients across rounds.
def count_federated_algorithm():

  def init(init_count):
    return {'count': init_count}

  def apply(state, clients):
    count = 0
    client_diagnostics = {}
    # Count sizes across clients.
    for client_id, client_dataset, _ in clients:
      # Summation across clients in one round.
      client_count = len(client_dataset)
      count += client_count
      client_diagnostics[client_id] = client_count
    # Summation across rounds.
    state = {'count': state['count'] + count}
    return state, client_diagnostics

  return FederatedAlgorithm(init, apply)

rng = jax.random.PRNGKey(0)  # Unused.
all_clients = [
    [
        (b'cid0', ClientDataset({'x': jnp.array([1, 2, 1, 2, 3, 4])}), rng),
        (b'cid1', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5])}), rng),
        (b'cid2', ClientDataset({'x': jnp.array([1, 1, 2])}), rng),
    ],
    [
        (b'cid3', ClientDataset({'x': jnp.array([1, 2, 3, 4])}), rng),
        (b'cid4', ClientDataset({'x': jnp.array([1, 1, 2, 1, 2, 3])}), rng),
        (b'cid5', ClientDataset({'x': jnp.array([1, 2, 3, 4, 5, 6, 7])}), rng),
    ],
]
algorithm = count_federated_algorithm()
state = algorithm.init(0)
for round_num in range(2):
  state, client_diagnostics = algorithm.apply(state, all_clients[round_num])
  print(round_num, state)
  print(round_num, client_diagnostics)
# 0 {'count': 14}
# 0 {b'cid0': 6, b'cid1': 5, b'cid2': 3}
# 1 {'count': 31}
# 1 {b'cid3': 4, b'cid4': 6, b'cid5': 7}
```
- init
Initializes the ServerState. Typically, the input to this method is the initial model Params. This should only be run once at the beginning of training.
- Type: Callable[[…], Any]
- apply
Completes one round of federated training given an input ServerState and a sequence of tuples of client identifier, client dataset, and client rng. The output is a new, updated ServerState and accumulated per step results keyed by client identifier (e.g. train metrics).
- Type: Callable[[Any, Sequence[Tuple[bytes, fedjax.core.client_datasets.ClientDataset, jax.Array]]], Tuple[Any, Mapping[bytes, Any]]]
- class fedjax.FederatedData
FederatedData interface for providing access to a federated dataset.
A FederatedData object serves as a mapping from client ids to client datasets and client metadata.
Access methods with better I/O efficiency
For large federated datasets, it is not feasible to load all client datasets into memory at once (whereas loading a single client dataset is assumed to be feasible). Different implementations exist for different on-disk storage formats. Since sequential reads are much faster than random reads for most storage technologies, FederatedData provides two types of methods for accessing client datasets, sequential and random access:
clients() and shuffled_clients() are sequential-read friendly, and thus recommended whenever appropriate.
get_clients() requires random reads, but prefetching is possible. It should be preferred over get_client().
get_client() is usually the slowest way of accessing client datasets, and is mostly intended for interactive exploration of a small number of clients.
Preprocessing
ClientDataset objects produced by FederatedData can hold a BatchPreprocessor, customizable via preprocess_batch(). Additionally, another “client” level ClientPreprocessor, customizable via preprocess_client(), can be used to apply transformations on examples from the entire client dataset before a ClientDataset is constructed.
- abstract client_ids()
Returns an iterator of client ids as bytes.
There is no requirement on the order of iteration.
- Return type: Iterator[bytes]
- abstract client_size(client_id)
Returns the number of examples in a client dataset.
- Return type: int
- abstract client_sizes()
Returns an iterator of all (client id, client size) pairs.
This is often more efficient than making multiple client_size() calls. There is no requirement on the order of iteration.
- Return type: Iterator[Tuple[bytes, int]]
- abstract clients()
Iterates over clients in a deterministic order.
Implementations can choose whatever order makes iteration efficient.
- Return type: Iterator[Tuple[bytes, ClientDataset]]
- abstract get_client(client_id)
Gets one single client dataset.
Prefer clients(), shuffled_clients(), or get_clients() when possible.
- Parameters: client_id (bytes) – Client id to load.
- Return type: ClientDataset
- Returns: The corresponding ClientDataset.
- abstract get_clients(client_ids)
Gets multiple clients in order with one call.
Clients are returned in the order of client_ids.
- Parameters: client_ids (Iterable[bytes]) – Client ids to load.
- Return type: Iterator[Tuple[bytes, ClientDataset]]
- Returns: An iterator of (client id, ClientDataset) pairs.
- abstract num_clients()
Returns the number of clients.
If it is too expensive or otherwise impossible to obtain the result, an implementation may raise an exception.
- Return type: int
- abstract preprocess_batch(fn)
Registers a preprocessing function to be called after batching in ClientDatasets.
- Return type: FederatedData
- abstract preprocess_client(fn)
Registers a preprocessing function to be called on all examples of a client before passing them to construct a ClientDataset.
- Return type: FederatedData
- abstract shuffled_clients(buffer_size, seed=None)
Iterates over clients with a repeated buffered shuffling.
Shuffling should use a buffer size of at least buffer_size clients. The iteration should repeat forever, usually with a different order in each pass.
- Parameters:
buffer_size (int) – Buffer size for shuffling.
seed (Optional[int]) – Optional random number generator seed.
- Return type: Iterator[Tuple[bytes, ClientDataset]]
- Returns: An infinite iterator of (client id, ClientDataset) pairs.
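The repeated buffered shuffling that shuffled_clients() describes can be sketched in plain Python. The helper buffered_shuffle_repeat is hypothetical and not part of fedjax. It shuffles an in-memory list instead of streaming clients from storage. It only illustrates how a fixed-size buffer mixes a sequential read order while the pass repeats forever.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar('T')


def buffered_shuffle_repeat(items: Iterable[T], buffer_size: int,
                            seed: Optional[int] = None) -> Iterator[T]:
  """Hypothetical sketch: repeats `items` forever with buffered shuffling."""
  items = list(items)
  if not items:
    return
  rng = random.Random(seed)
  buffer: List[T] = []
  while True:  # Each pass re-reads items in their stored (sequential) order.
    for item in items:
      if len(buffer) < buffer_size:
        buffer.append(item)  # Still filling the buffer.
        continue
      # Buffer is full: emit a random buffered item and put `item` in its slot.
      i = rng.randrange(buffer_size)
      yield buffer[i]
      buffer[i] = item


clients = [(b'cid%d' % i, i) for i in range(5)]
stream = buffered_shuffle_repeat(clients, buffer_size=3, seed=0)
first_ten = [next(stream)[0] for _ in range(10)]
print(first_ten)  # Ten client ids; the order depends on the seed.
```

A larger buffer gives a better shuffle at the cost of holding more client datasets in memory at once.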
- abstract slice(start=None, stop=None)
Returns a new FederatedData restricted to client ids in the given range.
The returned FederatedData includes clients whose ids are:
greater than or equal to start, when start is not None;
less than stop, when stop is not None.
- Parameters:
start (Optional[bytes]) – Start of client id range.
stop (Optional[bytes]) – Stop of client id range.
- Return type: FederatedData
- Returns: FederatedData.
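Here is a minimal sketch of the range semantics. It assumes plain lexicographic comparison of bytes ids, which is how Python orders bytes. slice_client_ids is a hypothetical helper, not a fedjax function. It filters a list of ids the same way slice() restricts a FederatedData.

```python
from typing import Iterable, List, Optional


def slice_client_ids(client_ids: Iterable[bytes],
                     start: Optional[bytes] = None,
                     stop: Optional[bytes] = None) -> List[bytes]:
  """Hypothetical helper: keeps ids with start <= id < stop (None = open)."""
  return [
      cid for cid in client_ids
      if (start is None or cid >= start) and (stop is None or cid < stop)
  ]


ids = [b'a1', b'b2', b'c3', b'd4']
print(slice_client_ids(ids, start=b'b'))  # [b'b2', b'c3', b'd4']
print(slice_client_ids(ids, stop=b'c3'))  # [b'a1', b'b2']
print(slice_client_ids(ids, b'b', b'd'))  # [b'b2', b'c3']
```

Because stop is exclusive, consecutive ranges such as (None, b'c') and (b'c', None) partition the ids without overlap.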
- class fedjax.SubsetFederatedData(base, client_ids, validate=True)
Bases: FederatedData
A simple wrapper over a concrete FederatedData for restricting to a subset of client ids.
This is useful when we wish to create a smaller FederatedData out of arbitrary client ids, where slicing is not possible.
- __init__(base, client_ids, validate=True)
Initializes the subset federated dataset.
- Parameters:
base (FederatedData) – Base concrete FederatedData.
client_ids (Iterable[bytes]) – Client ids to include in the subset. All client ids must be in base.client_ids(); otherwise the behavior of SubsetFederatedData is undefined when validate=False.
validate – Whether to validate client ids.
- class fedjax.SQLiteFederatedData(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))
Bases: FederatedData
Federated dataset backed by SQLite.
The SQLite database should contain a table named “federated_data” created with the following command:
```sql
CREATE TABLE federated_data (
  client_id BLOB NOT NULL PRIMARY KEY,
  data BLOB NOT NULL,
  num_examples INTEGER NOT NULL
);
```
where:
client_id is the bytes client id;
data is the serialized client dataset examples;
num_examples is the number of examples in the client dataset.
By default we use zlib-compressed msgpack blobs for data (see decompress_and_deserialize()).
- __init__(connection, parse_examples, start=None, stop=None, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))
- static new(path, parse_examples=<function decompress_and_deserialize>)
Opens a federated dataset stored as an SQLite3 database.
- Parameters:
path (str) – Path to the SQLite database file.
parse_examples (Callable[[bytes], Mapping[str, ndarray]]) – Function for deserializing client dataset examples.
- Return type: SQLiteFederatedData
- Returns: SQLiteFederatedData.
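The following sketch creates the federated_data table with Python's built-in sqlite3 module and stores one row per client. To keep the example dependency-free, it assumes a JSON plus zlib serialization via the hypothetical serialize/deserialize helpers. The actual default format is zlib-compressed msgpack, which this sketch does not reproduce. A database written this way would therefore not be readable with decompress_and_deserialize().

```python
import json
import sqlite3
import zlib
from typing import Dict


def serialize(examples: Dict[str, list]) -> bytes:
  # Hypothetical stand-in format (JSON + zlib), not fedjax's msgpack format.
  return zlib.compress(json.dumps(examples).encode('utf-8'))


def deserialize(blob: bytes) -> Dict[str, list]:
  return json.loads(zlib.decompress(blob).decode('utf-8'))


connection = sqlite3.connect(':memory:')
connection.execute("""
    CREATE TABLE federated_data (
      client_id BLOB NOT NULL PRIMARY KEY,
      data BLOB NOT NULL,
      num_examples INTEGER NOT NULL
    )""")

client_data = {
    b'client_a': {'x': [1.0, 2.0, 3.0], 'y': [0, 1, 0]},
    b'client_b': {'x': [4.0], 'y': [1]},
}
for client_id, examples in client_data.items():
  connection.execute(
      'INSERT INTO federated_data VALUES (?, ?, ?)',
      (client_id, serialize(examples), len(examples['y'])))
connection.commit()

# Read every client back sequentially, ordered by client id.
rows = connection.execute(
    'SELECT client_id, data, num_examples FROM federated_data '
    'ORDER BY client_id').fetchall()
for client_id, data, num_examples in rows:
  print(client_id, num_examples, deserialize(data))
```

Storing num_examples in its own column lets client sizes be answered without decompressing any data blob.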
- class fedjax.InMemoryFederatedData(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))
Bases: FederatedData
A simple wrapper over a concrete fedjax.FederatedData for small in-memory datasets.
This is useful when we wish to create a smaller FederatedData that fits in memory. Here is a simple example that creates a fedjax.InMemoryFederatedData:
```python
import numpy as np
import fedjax

client_a_data = {
    'x': np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),
    'y': np.array([7, 8]),
}
client_b_data = {'x': np.array([[9.0, 10.0, 11.0]]), 'y': np.array([12])}
# Client ids are bytes, as everywhere else in FederatedData.
client_to_data_mapping = {b'a': client_a_data, b'b': client_b_data}
fedjax.InMemoryFederatedData(client_to_data_mapping)
```
- Returns: A fedjax.InMemoryFederatedData corresponding to client_to_data_mapping.
- __init__(client_to_data_mapping, preprocess_client=ClientPreprocessor(()), preprocess_batch=BatchPreprocessor(()))
Initializes the in-memory federated dataset.
Data of each client is a mapping from feature names to NumPy arrays. For example, for EMNIST image classification, {'x': X, 'y': y}, where X is an array of shape (num_data_points, 28, 28) and y is an array of shape (num_data_points,).
- Parameters:
client_to_data_mapping (Mapping[bytes, Mapping[str, ndarray]]) – A mapping from client_id to data of each client.
preprocess_client (ClientPreprocessor) – federated_data.ClientPreprocessor to preprocess each client's data.
preprocess_batch (BatchPreprocessor) – client_datasets.BatchPreprocessor to preprocess batches of data.
- class fedjax.FederatedDataBuilder
FederatedDataBuilder interface.
To be implemented as a context manager for building file formats from pairs of client IDs and client NumPy examples.
Note that the add method below does not specify any raised exceptions. One could imagine some formats where add can fail in some way: out-of-order or duplicate inputs, remote files and network failures, individual entries too big for a format, etc. To address this, we let implementations raise whatever exceptions they see as relevant and fit for their particular use cases. The same applies to the __init__, __enter__, and __exit__ methods, where implementations are responsible for raising exceptions as they see fit, for example, if an invalid file path is passed, or there were issues finalizing the builder.
Example of end behavior:
```python
import numpy as np

# FederatedDataBuilder stands for a concrete implementation here, since the
# interface itself is abstract (e.g. fedjax.SQLiteFederatedDataBuilder).
with FederatedDataBuilder(path) as builder:
  builder.add(b'k1', np.array([b'v1'], dtype=object))
  builder.add(b'k2', np.array([b'v2'], dtype=object))
```
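To illustrate the context-manager shape of the interface, here is a hypothetical in-memory builder, DictFederatedDataBuilder. It is not a fedjax class, and it stores plain Python lists instead of NumPy arrays. It raises ValueError on duplicate client ids, as one example of an implementation-specific failure. It only publishes its results when the with-block exits without an exception.

```python
from typing import Dict, Mapping


class DictFederatedDataBuilder:
  """Hypothetical builder that collects client examples into a dict.

  A real builder would write a file when the with-block exits.
  """

  def __init__(self, output: Dict[bytes, Mapping[str, list]]):
    self._output = output
    self._pending: Dict[bytes, Mapping[str, list]] = {}

  def __enter__(self) -> 'DictFederatedDataBuilder':
    return self

  def add(self, client_id: bytes, examples: Mapping[str, list]) -> None:
    # Implementation-specific failure: this builder rejects duplicate ids.
    if client_id in self._pending:
      raise ValueError(f'Duplicate client id: {client_id!r}')
    self._pending[client_id] = examples

  def __exit__(self, exc_type, exc_value, traceback) -> bool:
    # Finalize only if the with-block finished without an exception.
    if exc_type is None:
      self._output.update(self._pending)
    return False  # Never suppress exceptions raised inside the block.


built = {}
with DictFederatedDataBuilder(built) as builder:
  builder.add(b'k1', {'x': [b'v1']})
  builder.add(b'k2', {'x': [b'v2']})
print(sorted(built))  # [b'k1', b'k2']
```

Deferring all writes to __exit__ means a failure halfway through leaves the output untouched, which is one reasonable way to handle "issues finalizing the builder".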
- class fedjax.SQLiteFederatedDataBuilder(path)
Bases: FederatedDataBuilder
Builds SQLite files from a Python dictionary containing an arbitrary mapping of client IDs to NumPy examples.
- class fedjax.ClientPreprocessor(fns=())
A chain of preprocessing functions on all examples of a client dataset.
This is very similar to fedjax.BatchPreprocessor, with the main difference being that ClientPreprocessor also takes client_id as input.
See the discussion in fedjax.BatchPreprocessor regarding when to use which.
- fedjax.shuffle_repeat_batch_federated_data(fd, batch_size, client_buffer_size, example_buffer_size, seed=None)
Shuffle-repeat-batch all client datasets in a federated dataset for training a centralized baseline.
Shuffling is done using two levels of buffered shuffling, first at the client level, then at the example level.
This produces an infinite stream of batches. itertools.islice() can be used to cap the number of batches, if so desired.
- Parameters:
fd (FederatedData) – Federated dataset.
batch_size (int) – Desired batch size.
client_buffer_size (int) – Buffer size for client level shuffling.
example_buffer_size (int) – Buffer size for example level shuffling.
seed (Optional[int]) – Optional RNG seed.
- Yields: Batches of preprocessed examples.
- Return type: Iterator[Mapping[str, ndarray]]
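Capping the infinite stream with itertools.islice() looks like the sketch below. To stay self-contained, it uses a hypothetical stand-in generator, fake_batch_stream, in place of a real FederatedData. With fedjax, the same islice call would wrap fedjax.shuffle_repeat_batch_federated_data(fd, batch_size, client_buffer_size, example_buffer_size, seed).

```python
import itertools
from typing import Dict, Iterator, List


def fake_batch_stream(batch_size: int) -> Iterator[Dict[str, List[int]]]:
  """Hypothetical stand-in for an infinite stream of batches."""
  for step in itertools.count():
    start = step * batch_size
    yield {'x': list(range(start, start + batch_size))}


# Keep only the first 3 batches of the otherwise endless stream.
capped = list(itertools.islice(fake_batch_stream(batch_size=2), 3))
print(capped)  # [{'x': [0, 1]}, {'x': [2, 3]}, {'x': [4, 5]}]
```

islice is lazy, so only the batches actually requested are ever produced.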
- fedjax.padded_batch_federated_data(fd, hparams=None, **kwargs)
Padded batch all client datasets, useful for evaluation on the entire federated dataset.
- Parameters:
fd (FederatedData) – Federated dataset.
hparams (Optional[PaddedBatchHParams]) – See fedjax.padded_batch_client_datasets().
**kwargs – See fedjax.padded_batch_client_datasets().
- Yields: Batches of preprocessed examples.
- Return type: Iterator[Mapping[str, ndarray]]
- class fedjax.RepeatableIterator(base)
Repeats a base iterable after the end of the first pass of iteration.
Because this is a stateful object, it is not thread safe, and all the usual caveats about accessing the same iterator from different locations apply. For example, if we make two map calls on the same RepeatableIterator, we must make sure we do not interleave next() calls on them. The following is safe because we finish iterating on m1 before starting to iterate on m2:
```python
from fedjax import RepeatableIterator

it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
# We finish iterating on m1 before starting to iterate on m2.
print(list(m1), list(m2))
# [1, 2, 3, 4] [0, 1, 4, 9]
```
Whereas interleaved access leads to confusing results:
```python
it = RepeatableIterator(range(4))
m1 = map(lambda x: x + 1, it)
m2 = map(lambda x: x * x, it)
print(next(m1), next(m2))
# 1 1
print(next(m1), next(m2))
# 3 9
print(next(m1), next(m2))
# StopIteration!
```
In the first pass of iteration, values fetched from the base iterator are copied onto an internal buffer (except for a few builtin containers where copying is unnecessary). When each pass of iteration finishes (i.e. when __next__() raises StopIteration), the iterator resets itself to the start of the buffer, thus allowing a subsequent pass of repeated iteration.
In most cases, if repeated iterations are required, it is sufficient to simply copy values from an iterator into a list. However, when an iterator produces values via potentially expensive I/O operations (e.g. loading client datasets), RepeatableIterator can interleave I/O and JAX compute to decrease accelerator idle time.
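The buffering and rewind behavior described above can be approximated by a small class. SimpleRepeatableIterator is a hypothetical, simplified sketch. Unlike fedjax.RepeatableIterator, it never copies values and does not overlap I/O with computation. It only shows how recording the first pass and rewinding on StopIteration enables repeated passes.

```python
from typing import Any, Iterable, List


class SimpleRepeatableIterator:
  """Hypothetical sketch: records the first pass, then replays it.

  Not thread safe, and values are stored without copying.
  """

  def __init__(self, base: Iterable[Any]):
    self._base = iter(base)
    self._buffer: List[Any] = []
    self._first_pass = True
    self._index = 0

  def __iter__(self) -> 'SimpleRepeatableIterator':
    return self

  def __next__(self) -> Any:
    if self._first_pass:
      try:
        value = next(self._base)
      except StopIteration:
        # First pass is over: later passes replay the buffer from the start.
        self._first_pass = False
        self._index = 0
        raise
      self._buffer.append(value)
      return value
    if self._index < len(self._buffer):
      value = self._buffer[self._index]
      self._index += 1
      return value
    # End of a repeated pass: rewind so the next pass starts over.
    self._index = 0
    raise StopIteration


it = SimpleRepeatableIterator(range(4))
first = list(map(lambda x: x + 1, it))
second = list(map(lambda x: x * x, it))
print(first, second)  # [1, 2, 3, 4] [0, 1, 4, 9]
```

The same interleaving caveat applies here: two consumers calling next() alternately share one position in the buffer.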
Preprocessing and batching operations over client datasets.
Column based representation
The examples in a client dataset can be viewed as a table, where the rows are the individual examples, and the columns are the features (labels are viewed as a feature in this context).
We use a column based representation when loading a dataset into memory.
Each column is a NumPy array x of rank at least 1, where x[i, ...] is the value of this feature for the i-th example.
The complete set of examples is a dict-like object, from str feature names to the corresponding column values.
Traditionally, a row based representation is used for representing the entire dataset, and a column based representation is used for a single batch. In the context of federated learning, an individual client dataset is small enough to easily fit into memory, so the same representation is used for the entire dataset and a batch.
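The following sketch shows the column based representation without any dependencies. It uses Python lists as stand-ins for NumPy arrays; in fedjax, each column would be a NumPy array of rank at least 1. rows_to_columns is a hypothetical helper that converts row based examples into one column per feature. The last lines show that a batch is just the same structure with fewer rows.

```python
from typing import Dict, List, Sequence


def rows_to_columns(
    rows: Sequence[Dict[str, object]]) -> Dict[str, List[object]]:
  """Hypothetical helper: one dict per example -> one list per feature."""
  if not rows:
    return {}
  return {name: [row[name] for row in rows] for name in rows[0]}


rows = [
    {'x': [1.0, 2.0], 'y': 0},
    {'x': [3.0, 4.0], 'y': 1},
    {'x': [5.0, 6.0], 'y': 1},
]
columns = rows_to_columns(rows)
# columns['x'][i] and columns['y'][i] together describe the i-th example.
print(columns['y'])  # [0, 1, 1]
# A batch has exactly the same structure, just with fewer rows per column.
batch = {name: column[:2] for name, column in columns.items()}
print(batch['x'])  # [[1.0, 2.0], [3.0, 4.0]]
```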
Preprocessor
Preprocessing on a batch of examples can be easily done via a chain of functions. A Preprocessor object holds the chain of functions, and applies the transformation on a batch of examples.
ClientDataset: examples + preprocessor
A ClientDataset is simply some examples in the column based representation, accompanied by a Preprocessor.
Its batch() method produces batches of examples in a sequential order, suitable for evaluation.
Its shuffle_repeat_batch() method adds shuffling and repeating, making it suitable for training.
- class fedjax.ClientDataset(raw_examples, preprocessor=BatchPreprocessor(()))
In-memory client dataset backed by NumPy ndarrays.
Custom preprocessing on batches can be added via a preprocessor. A ClientDataset is stored as the unpreprocessed raw_examples, along with its preprocessor.
To access batches, use one of the batching functions (e.g. shuffle_repeat_batch() for training, padded_batch() for evaluation).
To access a small number of preprocessed examples (e.g. for exploration), use slicing + all_examples().
This is only intended for efficient access to small datasets that fit in memory.
- all_examples()
Returns the result of feeding all raw examples through the preprocessor.
This is mostly intended for interactive exploration of a small subset of a client dataset. For example, to see the first 4 examples in a client dataset:
```python
dataset = ClientDataset(my_raw_examples, my_preprocessor)
dataset[:4].all_examples()
```
- Return type: Mapping[str, ndarray]
- Returns: Preprocessed examples from all the raw examples in this client dataset.
- batch(hparams=None, **kwargs)
Produces preprocessed batches in a fixed sequential order.
The final batch may contain fewer than batch_size examples. If used directly, that may result in a large number of JIT recompilations. Therefore we recommend using padded_batch instead in most scenarios.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code if you have to use batch (prefer padded_batch() if possible). Example:
```python
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.batch(hparams):
    ...
```
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example:
```python
client_dataset.batch(batch_size=2)
# Overrides the default drop_remainder value.
client_dataset.batch(hparams, drop_remainder=True)
```
- Parameters:
hparams (Optional[BatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type: Iterable[Mapping[str, ndarray]]
- Returns: An iterable object that can be iterated over multiple times.
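To see why the short final batch matters, the sketch below computes the batch sizes produced by sequential batching and collects the distinct sizes across a few clients. Under jax.jit, each new input shape triggers a separate compilation, so every distinct size is a potential recompilation. sequential_batch_sizes is a hypothetical helper, and the client sizes are made-up numbers.

```python
from typing import List


def sequential_batch_sizes(num_examples: int, batch_size: int,
                           drop_remainder: bool = False) -> List[int]:
  """Hypothetical helper: sizes of batches produced in a fixed order."""
  num_full, remainder = divmod(num_examples, batch_size)
  sizes = [batch_size] * num_full
  if remainder and not drop_remainder:
    sizes.append(remainder)  # The short final batch.
  return sizes


print(sequential_batch_sizes(10, 4))  # [4, 4, 2]
print(sequential_batch_sizes(10, 4, drop_remainder=True))  # [4, 4]

# Distinct batch shapes seen across several (made-up) client sizes.
client_sizes = [10, 7, 13, 5]
shapes = {size for n in client_sizes for size in sequential_batch_sizes(n, 4)}
print(sorted(shapes))  # [1, 2, 3, 4]
```

With many clients, the number of distinct final sizes can approach batch_size itself, which is the recompilation cost that padded_batch avoids.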
- padded_batch(hparams=None, **kwargs)
Produces preprocessed padded batches in a fixed sequential order.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
```python
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.padded_batch(hparams):
    ...
```
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example:
```python
client_dataset.padded_batch(batch_size=2)
# Overrides the default num_batch_size_buckets value.
client_dataset.padded_batch(hparams, num_batch_size_buckets=2)
```
When the number of examples in the dataset is not a multiple of batch_size, the final batch may be smaller than batch_size. This may lead to a large number of JIT recompilations. This can be circumvented by padding the final batch to a small number of fixed sizes controlled by num_batch_size_buckets.
All batches contain an extra bool feature keyed by EXAMPLE_MASK_KEY. batch[EXAMPLE_MASK_KEY][i] tells us whether the i-th example in this batch is an actual example (batch[EXAMPLE_MASK_KEY][i] == True), or a padding example (batch[EXAMPLE_MASK_KEY][i] == False).
We repeatedly halve the batch size up to num_batch_size_buckets-1 times, until we find the smallest one that is also >= the size of the final batch. Therefore if batch_size < 2^num_batch_size_buckets, fewer bucket sizes will actually be used.
- Parameters:
hparams (Optional[PaddedBatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type: Iterable[Mapping[str, ndarray]]
- Returns: An iterable object that can be iterated over multiple times.
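The halving rule and the example mask can be sketched as follows. pick_padded_size and padded_batches are hypothetical helpers operating on Python lists. The sketch assumes halving uses integer division; fedjax's exact rounding and its mask key are not reproduced here. The boolean list returned next to each batch plays the role of batch[EXAMPLE_MASK_KEY].

```python
from typing import List, Sequence, Tuple


def pick_padded_size(final_size: int, batch_size: int,
                     num_batch_size_buckets: int) -> int:
  """Smallest of batch_size, batch_size // 2, ... that is >= final_size."""
  best = batch_size
  size = batch_size
  for _ in range(num_batch_size_buckets - 1):  # At most buckets-1 halvings.
    size //= 2
    if size < final_size:
      break  # Further halving only gets smaller.
    best = size
  return best


def padded_batches(values: Sequence[int], batch_size: int,
                   num_batch_size_buckets: int,
                   pad_value: int = 0) -> List[Tuple[List[int], List[bool]]]:
  """Hypothetical: returns (padded batch, mask) pairs in sequential order."""
  batches = []
  for start in range(0, len(values), batch_size):
    chunk = list(values[start:start + batch_size])
    num_real = len(chunk)
    size = num_real
    if num_real < batch_size:
      size = pick_padded_size(num_real, batch_size, num_batch_size_buckets)
    mask = [True] * num_real + [False] * (size - num_real)
    chunk += [pad_value] * (size - num_real)
    batches.append((chunk, mask))
  return batches


for values, mask in padded_batches(list(range(9)), 4, 2):
  print(values, mask)
# [0, 1, 2, 3] [True, True, True, True]
# [4, 5, 6, 7] [True, True, True, True]
# [8, 0] [True, False]
```

Metrics computed on such batches should only count positions where the mask is True, so padding never affects averages.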
- shuffle_repeat_batch(hparams=None, **kwargs)
Produces preprocessed batches in a shuffled and repeated order.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
```python
def a_library_function(client_dataset, hparams):
  for batch in client_dataset.shuffle_repeat_batch(hparams):
    ...
```
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example:
```python
client_dataset.shuffle_repeat_batch(batch_size=2)
# Overrides the default num_epochs value.
client_dataset.shuffle_repeat_batch(hparams, num_epochs=2)
```
Shuffling is done without replacement; therefore, for a dataset of N examples, the first ceil(N/batch_size) batches are guaranteed to cover the entire dataset.
By default the iteration stops after the first epoch. The number of batches produced from the iteration can be controlled by the (num_epochs, num_steps, drop_remainder) combination:
If both num_epochs and num_steps are None, the shuffle-repeat process continues forever.
If num_epochs is set and num_steps is None, as few batches as needed to go over the dataset this many passes are produced. Further:
If drop_remainder is False (the default), the final batch is filled with additionally sampled examples to contain batch_size examples.
If drop_remainder is True, the final batch is dropped if it contains fewer than batch_size examples. This may result in examples being skipped when num_epochs=1.
If num_steps is set and num_epochs is None, exactly this many batches are produced. drop_remainder has no effect in this case.
If both num_epochs and num_steps are set, the smaller of the two resulting numbers of batches is produced.
If reproducible iteration order is desired, a fixed seed can be used. When seed is None, repeated iteration over the same object may produce batches in a different order.
Unlike batch() or padded_batch(), batches from shuffle_repeat_batch() always contain exactly batch_size examples. Also unlike TensorFlow, that holds even when drop_remainder=False.
- Parameters:
hparams (Optional[ShuffleRepeatBatchHParams]) – Batching hyperparameters.
**kwargs – Keyword arguments for constructing/overriding hparams.
- Return type: Iterable[Mapping[str, ndarray]]
- Returns: An iterable object that can be iterated over multiple times.
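The batch-count rules above can be restated as a small function. num_shuffle_repeat_batches is hypothetical. It assumes num_epochs defaults to 1, matching "by default the iteration stops after the first epoch", and it returns None for the run-forever case.

```python
from typing import Optional


def num_shuffle_repeat_batches(num_examples: int, batch_size: int,
                               num_epochs: Optional[int] = 1,
                               num_steps: Optional[int] = None,
                               drop_remainder: bool = False) -> Optional[int]:
  """Hypothetical restatement of the rules above; None means "forever"."""
  if num_epochs is None and num_steps is None:
    return None  # Shuffle-repeat continues forever.
  if num_epochs is None:
    return num_steps  # drop_remainder has no effect here.
  total = num_examples * num_epochs
  if drop_remainder:
    by_epochs = total // batch_size  # A short final batch is dropped.
  else:
    by_epochs = -(-total // batch_size)  # Ceiling: last batch is filled up.
  return by_epochs if num_steps is None else min(by_epochs, num_steps)


print(num_shuffle_repeat_batches(10, 4))  # 3
print(num_shuffle_repeat_batches(10, 4, drop_remainder=True))  # 2
print(num_shuffle_repeat_batches(10, 4, num_epochs=3))  # 8
print(num_shuffle_repeat_batches(10, 4, num_epochs=3, num_steps=5))  # 5
print(num_shuffle_repeat_batches(10, 4, num_epochs=None, num_steps=5))  # 5
```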
- class fedjax.BatchPreprocessor(fns=())
A chain of preprocessing functions on batched examples.
BatchPreprocessor holds a chain of preprocessing functions, and applies them in order on batched examples. Each individual preprocessing function operates over multiple examples, instead of just 1 example. For example:
```python
import numpy as np
from fedjax import BatchPreprocessor

preprocessor = BatchPreprocessor([
    # Flattens `pixels`.
    lambda x: {**x, 'pixels': x['pixels'].reshape([-1, 28 * 28])},
    # Introduces `binary_label`.
    lambda x: {**x, 'binary_label': x['label'] % 2},
])
fake_emnist = {
    'pixels': np.random.uniform(size=(10, 28, 28)),
    'label': np.random.randint(10, size=(10,))
}
preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "label" and "binary_label".
```
Given a BatchPreprocessor, a new BatchPreprocessor can be created with an additional preprocessing function appended to the chain:
```python
# Continuing from the previous example.
new_preprocessor = preprocessor.append(
    lambda x: {**x, 'sum_pixels': np.sum(x['pixels'], axis=1)})
new_preprocessor(fake_emnist)
# Produces a dict of [10, 28*28] "pixels", [10,] "sum_pixels", "label" and
# "binary_label".
```
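The chaining and append behavior can be mimicked without fedjax or NumPy. SimplePreprocessor is a hypothetical stand-in that applies its functions in order to a whole batch. Its append returns a new object instead of mutating the original; that is the design choice assumed here.

```python
from typing import Callable, Dict, Sequence, Tuple

Examples = Dict[str, list]
Fn = Callable[[Examples], Examples]


class SimplePreprocessor:
  """Hypothetical stand-in for a chain of batch preprocessing functions."""

  def __init__(self, fns: Sequence[Fn] = ()):
    self._fns: Tuple[Fn, ...] = tuple(fns)

  def __call__(self, examples: Examples) -> Examples:
    # Apply each function, in order, to the whole batch.
    for fn in self._fns:
      examples = fn(examples)
    return examples

  def append(self, fn: Fn) -> 'SimplePreprocessor':
    # Build a new chain; the original chain is left unchanged.
    return SimplePreprocessor(self._fns + (fn,))


base = SimplePreprocessor([
    lambda x: {**x, 'binary_label': [label % 2 for label in x['label']]},
])
extended = base.append(
    lambda x: {**x, 'label_plus_one': [label + 1 for label in x['label']]})
batch = {'label': [3, 4, 7]}
print(base(batch))  # {'label': [3, 4, 7], 'binary_label': [1, 0, 1]}
print(extended(batch))  # Also contains 'label_plus_one': [4, 5, 8].
```

Keeping the original chain immutable means two datasets can share a base preprocessor and extend it differently without interfering.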
The main difference between this preprocessor and fedjax.ClientPreprocessor is that fedjax.ClientPreprocessor also takes client_id as input. Because batched examples and all examples in a client dataset share an identical representation, certain preprocessing can be done with either BatchPreprocessor or ClientPreprocessor.
Examples of preprocessing possible at either the client dataset level or the batch level
Such preprocessing is deterministic, and strictly per-example.
Casting a feature from int8 to float32.
Adding a new feature derived from existing features.
Removing a feature (although the better place to do so is at the dataset level).
A simple rule for deciding where to carry out the preprocessing in this case is the following:
Does this make batching cheaper (e.g. removing features)? If so, do it at the dataset level.
Otherwise, do it at the batch level.
Assuming preprocessing time is linear in the number of examples, preprocessing at the batch level has the benefit of evenly distributing host compute work, which may overlap better with asynchronous JAX compute work on GPU/TPU.
Examples of preprocessing only possible at the batch level
Data augmentation (e.g. random cropping).
Padding at the batch size dimension.
Examples of preprocessing only possible at the dataset level
Those that require knowing the client id.
Capping the number of examples.
Altering what it means to be an example: e.g. in certain language model setups, sentences are concatenated and then split into equal sized chunks.
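As an example of preprocessing that belongs at the dataset level, the hypothetical function below both needs the client id and caps the number of examples. It assumes a client-level function receives (client_id, examples) in that order. The feature name client_tag and the cap of 3 are illustrative only.

```python
from typing import Dict, List

Examples = Dict[str, List[object]]


def cap_and_tag(client_id: bytes, examples: Examples,
                max_examples: int = 3) -> Examples:
  """Hypothetical client-level preprocessing function.

  It needs client_id, which batch-level functions never see, and it changes
  the number of examples, which only makes sense for a whole client dataset.
  """
  capped = {name: column[:max_examples] for name, column in examples.items()}
  num_examples = len(next(iter(capped.values()), []))
  capped['client_tag'] = [client_id] * num_examples
  return capped


out = cap_and_tag(b'cid0', {'x': [1, 2, 3, 4, 5], 'y': [0, 1, 0, 1, 0]})
print(out)
# {'x': [1, 2, 3], 'y': [0, 1, 0], 'client_tag': [b'cid0', b'cid0', b'cid0']}
```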
- fedjax.buffered_shuffle_batch_client_datasets(datasets, batch_size, buffer_size, rng)
Shuffles and batches examples from multiple client datasets.
This makes only 1 pass over the examples. To achieve repeated iterations, create an infinite shuffled stream of datasets first (e.g. using buffered_shuffle()).
- Parameters:
datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.
batch_size (int) – Desired batch size.
buffer_size (int) – Number of examples to buffer during shuffling.
rng (RandomState) – Source of randomness.
- Yields: Batches of examples. For a finite stream of datasets, the final batch might be smaller than batch_size.
- Raises:
ValueError – If any 2 client datasets have different Preprocessors.
ValueError – If any 2 client datasets have different features.
- Return type: Iterator[Mapping[str, ndarray]]
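Here is a pure-Python sketch of one pass of buffered shuffling followed by batching across several datasets. buffered_shuffle_batch and its helpers are hypothetical. They use random.Random instead of a NumPy RandomState and represent datasets as dicts of equal-length lists. They also skip the Preprocessor and feature-consistency checks that raise ValueError in fedjax.

```python
import random
from typing import Dict, Iterable, Iterator, List

Example = Dict[str, object]
Examples = Dict[str, list]


def _rows(dataset: Examples) -> Iterator[Example]:
  # Column based dataset -> one dict per example.
  names = list(dataset)
  num_rows = len(dataset[names[0]]) if names else 0
  for i in range(num_rows):
    yield {name: dataset[name][i] for name in names}


def _shuffled_rows(datasets: Iterable[Examples], buffer_size: int,
                   rng: random.Random) -> Iterator[Example]:
  buffer: List[Example] = []
  for dataset in datasets:
    for row in _rows(dataset):
      buffer.append(row)
      if len(buffer) >= buffer_size:
        # Move a random buffered example to the end and emit it.
        i = rng.randrange(len(buffer))
        buffer[i], buffer[-1] = buffer[-1], buffer[i]
        yield buffer.pop()
  rng.shuffle(buffer)  # End of the stream: drain the rest in random order.
  yield from buffer


def buffered_shuffle_batch(datasets: Iterable[Examples], batch_size: int,
                           buffer_size: int,
                           rng: random.Random) -> Iterator[Examples]:
  """Hypothetical sketch: one pass; the final batch may be short."""
  batch: List[Example] = []
  for row in _shuffled_rows(datasets, buffer_size, rng):
    batch.append(row)
    if len(batch) == batch_size:
      yield {name: [r[name] for r in batch] for name in batch[0]}
      batch = []
  if batch:
    yield {name: [r[name] for r in batch] for name in batch[0]}


datasets = [{'x': [1, 2, 3]}, {'x': [4, 5]}, {'x': [6, 7, 8, 9]}]
batches = list(buffered_shuffle_batch(datasets, batch_size=4, buffer_size=3,
                                      rng=random.Random(0)))
print([len(b['x']) for b in batches])  # [4, 4, 1]
```

Examples from neighboring datasets mix inside the buffer, which is why a batch can contain examples from more than one client.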
- fedjax.padded_batch_client_datasets(datasets, hparams=None, **kwargs)
Batches examples from multiple client datasets.
This is useful when we want to evaluate on the combined dataset consisting of multiple client datasets. Unlike batching each client dataset individually, this reduces the number of batches smaller than batch_size.
This function can be invoked in 2 ways:
Using a hyperparams object. This is the recommended way in library code. Example:
```python
def a_library_function(datasets, hparams):
  for batch in padded_batch_client_datasets(datasets, hparams):
    ...
```
Using keyword arguments. The keyword arguments are used to construct a new hyperparams object, or override an existing one. For example:
```python
padded_batch_client_datasets(datasets, batch_size=2)
# Overrides the default num_batch_size_buckets value.
padded_batch_client_datasets(datasets, hparams, num_batch_size_buckets=2)
```
- Parameters:
datasets (Iterable[ClientDataset]) – ClientDatasets to be batched. All ClientDatasets must have the same Preprocessor object attached.
hparams (Optional[PaddedBatchHParams]) – Batching hyperparams like those in ClientDataset.padded_batch().
**kwargs – Keyword arguments for constructing/overriding hparams.
- Yields: Batches of examples. The final batch might be padded. All batches contain a bool feature keyed by EXAMPLE_MASK_KEY.
- Raises:
ValueError – If any 2 client datasets have different Preprocessors.
ValueError – If any 2 client datasets have different features.
- Return type: Iterator[Mapping[str, ndarray]]
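The saving is easy to quantify. The hypothetical helper below counts total batches and short (partial) batches for some made-up client sizes. It first batches each client separately, then batches the combined examples.

```python
from typing import Sequence, Tuple


def count_batches(sizes: Sequence[int], batch_size: int) -> Tuple[int, int]:
  """Hypothetical helper: (num_batches, num_partial_batches)."""
  num_batches = sum(-(-n // batch_size) for n in sizes)  # Ceiling division.
  num_partial = sum(1 for n in sizes if n % batch_size)
  return num_batches, num_partial


client_sizes = [5, 3, 7]
# Batching each client separately: every client can end in a short batch.
print(count_batches(client_sizes, batch_size=4))  # (5, 3)
# Batching the combined examples: at most one short batch overall.
print(count_batches([sum(client_sizes)], batch_size=4))  # (4, 1)
```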
- fedjax.for_each_client(client_init, client_step, client_final=<function <lambda>>, with_step_result=False)[source]
Creates a function which maps over clients.
For example, for_each_client could be used to define how to run client updates for each client in a federated training round. Another common use case of for_each_client is to run evaluation per client for a given set of model parameters.
The underlying backend for for_each_client is customizable. For example, if multiple devies are available (e.g. TPU), a
jax.pmap()
based backend can be used to parallelize across devices. It’s also possible to manually specify which backend to use (for debugging).The expected usage of for_each_client is as follows:
# Map over clients and count how many points are greater than `limit` for # each client. Each client also has a different `start` that is specified # via client input. def client_init(shared_input, client_input): client_step_state = { 'limit': shared_input['limit'], 'count': client_input['start'] } return client_step_state def client_step(client_step_state, batch): num = jnp.sum(batch['x'] > client_step_state['limit']) client_step_state = { 'limit': client_step_state['limit'], 'count': client_step_state['count'] + num } return client_step_state def client_final(shared_input, client_step_state): del shared_input # Unused. return client_step_state['count'] # Three clients with different data and starting counts. # clients = [(client_id, client_batches, client_input)] clients = [ (b'cid0', [{'x': jnp.array([1, 2, 3, 4])}, {'x': jnp.array([1, 2, 3])}], {'start': jnp.array(2)}), (b'cid1', [{'x': jnp.array([1, 2])}, {'x': jnp.array([1, 2, 3, 4, 5])}], {'start': jnp.array(0)}), (b'cid2', [{'x': jnp.array([1])}], {'start': jnp.array(1)}), ] shared_input = {'limit': jnp.array(2)} func = fedjax.for_each_client(client_init, client_step, client_final) print(list(func(shared_input, clients))) # [(b'cid0', 5), (b'cid1', 3), (b'cid2', 1)]
Here’s the same example with per step results.
# We'll also keep track of the `num` per step in our step results. def client_step_with_result(client_step_state, batch): num = jnp.sum(batch['x'] > client_step_state['limit']) client_step_state = { 'limit': client_step_state['limit'], 'count': client_step_state['count'] + num } client_step_result = {'num': num} return client_step_state, client_step_result func = fedjax.for_each_client( client_init, client_step_with_result, client_final, with_step_result=True) print(list(func(shared_input, clients))) # [ # (b'cid0', 5, [{'num': 2}, {'num': 1}]), # (b'cid1', 3, [{'num': 0}, {'num': 3}]), # (b'cid2', 1, [{'num': 0}]), # ]
- Parameters:
client_init (Callable[[Any, Any], Any]) – Function that initializes the intermediate client step state from the shared input and the per client input. The shared input contains information that is shared across all clients, such as the global model parameters. The per client input contains information specific to each client. The initialized client step state is passed as input to, and returned as output from, client_step, and is finally consumed by client_final. This client step state usually contains the model parameters and optimizer state for each client, which are updated at each client_step. This will be run once for each client.
client_step (Union[Callable[[Any, Mapping[str, Array]], Tuple[Any, Any]], Callable[[Any, Mapping[str, Array]], Any]]) – Function that takes the client step state and a batch of examples as input and outputs a (possibly updated) client step state. If with_step_result is True, per step results are also returned as the second element of a pair. Per step results are usually diagnostics like gradient norm. This will be run for each batch for each client.
client_final (Callable[[Any, Any], Any]) – Function that applies the final transformation to the client step state to produce the desired final client output. More meaningful transformations, like model update clipping, can be done here. Defaults to just returning the client step state. This will be run once for each client.
with_step_result (bool) – Indicates whether client_step returns a pair where the first element is the updated client step state and the second element is the per step result.
- Returns:
A for each client function that takes shared_input and the per client inputs, given as tuples (client_id, client_batches, client_input), to map over, and returns the output for each client as produced by client_final, along with the optional per client per step results.
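To make the call order concrete, here is a sequential, pure-Python sketch of the semantics described for for_each_client. sequential_for_each_client is a hypothetical name, not part of fedjax: it ignores jit/pmap and device placement, uses plain lists instead of jnp arrays, and assumes each client is given as a (client_id, client_batches, client_input) tuple as in the example.

```python
from typing import Any, Callable, Iterable, Iterator, Tuple


def sequential_for_each_client(
    client_init: Callable[[Any, Any], Any],
    client_step: Callable[[Any, Any], Any],
    client_final: Callable[[Any, Any], Any] = lambda shared_input, state: state,
    with_step_result: bool = False,
) -> Callable[[Any, Iterable[Tuple[Any, Iterable[Any], Any]]], Iterator[tuple]]:
  """Hypothetical sequential stand-in for fedjax.for_each_client (no jit/pmap)."""

  def run(shared_input, clients):
    for client_id, client_batches, client_input in clients:
      # client_init runs once per client.
      state = client_init(shared_input, client_input)
      step_results = []
      for batch in client_batches:
        # client_step runs once per batch, threading the step state through.
        if with_step_result:
          state, step_result = client_step(state, batch)
          step_results.append(step_result)
        else:
          state = client_step(state, batch)
      # client_final runs once per client on the last step state.
      output = client_final(shared_input, state)
      if with_step_result:
        yield client_id, output, step_results
      else:
        yield client_id, output

  return run


# The counting example, rewritten with plain Python lists instead of jnp arrays.
def client_init(shared_input, client_input):
  return {'limit': shared_input['limit'], 'count': client_input['start']}


def client_step_with_result(state, batch):
  num = sum(1 for x in batch['x'] if x > state['limit'])
  return {'limit': state['limit'], 'count': state['count'] + num}, {'num': num}


def client_final(shared_input, state):
  del shared_input  # Unused.
  return state['count']


clients = [
    (b'cid0', [{'x': [1, 2, 3, 4]}, {'x': [1, 2, 3]}], {'start': 2}),
    (b'cid1', [{'x': [1, 2]}, {'x': [1, 2, 3, 4, 5]}], {'start': 0}),
    (b'cid2', [{'x': [1]}], {'start': 1}),
]
func = sequential_for_each_client(
    client_init, client_step_with_result, client_final, with_step_result=True)
print(list(func({'limit': 2}, clients)))
# [(b'cid0', 5, [{'num': 2}, {'num': 1}]), (b'cid1', 3, [{'num': 0}, {'num': 3}]), (b'cid2', 1, [{'num': 0}])]
```

The sketch only mirrors the order of calls; the real backends compile this loop and may parallelize it across devices.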
- fedjax.for_each_client_backend(backend)
A context manager for switching to a given ForEachClientBackend in the current thread.
Example:
with fedjax.for_each_client_backend('pmap'):
  # We will be using the pmap based for_each_client backend within this block.
  pass
# We will be using the default for_each_client backend from now on.
- Parameters:
backend (Union[ForEachClientBackend, str, None]) – See set_for_each_client_backend().
- Yields:
Nothing.
- fedjax.set_for_each_client_backend(backend)
Sets the for_each_client backend for the current thread.
- Parameters:
backend (Union[ForEachClientBackend, str, None]) – One of the following:
None: uses the default backend for the current environment.
'debug': uses the debugging backend.
'jit': uses the JIT backend.
'pmap': uses the pmap-based backend.
A concrete ForEachClientBackend object.
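The thread-scoped behavior of set_for_each_client_backend() and for_each_client_backend() can be modeled in plain Python with threading.local. This is a hypothetical sketch, not fedjax's implementation: set_backend, get_backend, and backend_scope are illustrative names, plain strings stand in for ForEachClientBackend objects, and the context manager is assumed to restore whichever backend was active before the with block.

```python
import contextlib
import threading

# Hypothetical sketch: plain strings stand in for fedjax's backend objects.
_BACKEND_NAMES = (None, 'debug', 'jit', 'pmap')
_thread_state = threading.local()


def set_backend(backend):
  """Records the backend choice for the calling thread only."""
  if (backend is None or isinstance(backend, str)) and backend not in _BACKEND_NAMES:
    raise ValueError(f'Unsupported backend: {backend!r}')
  _thread_state.backend = backend


def get_backend():
  # Threads that never called set_backend see None, i.e. "use the default".
  return getattr(_thread_state, 'backend', None)


@contextlib.contextmanager
def backend_scope(backend):
  """Uses `backend` inside the `with` block, then restores the previous choice."""
  previous = get_backend()
  set_backend(backend)
  try:
    yield
  finally:
    set_backend(previous)


with backend_scope('pmap'):
  inside = get_backend()
outside = get_backend()

# A thread started inside the block does not inherit the setting.
seen_by_other_thread = []
with backend_scope('debug'):
  worker = threading.Thread(target=lambda: seen_by_other_thread.append(get_backend()))
  worker.start()
  worker.join()

print(inside, outside, seen_by_other_thread)  # pmap None [None]
```

Keeping the choice per thread means two threads can run for_each_client functions with different backends without interfering with each other.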
- class fedjax.Model(init, apply_for_train, apply_for_eval, train_loss, eval_metrics)
Container class for models.
Model exists to provide easy access to predefined neural network models. It is meant to contain all the information needed for standard centralized training and evaluation. Non-standard training methods can be built upon the information available in Model along with any additional information (e.g. interpolation can be implemented as a composition of two models along with an interpolation weight).
Works for Haiku and jax.example_libraries.stax.
The expected usage of Model is as follows:
import fedjax
import jax
import jax.numpy as jnp

# Assumes `model` is a fedjax.Model and `batches` is a re-iterable sequence of batches.

# Training.
step_size = 0.1
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
params = model.init(init_rng)

def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return jnp.sum(model.train_loss(batch, preds))

grad_fn = jax.grad(loss)
for batch in batches:
  rng, use_rng = jax.random.split(rng)
  grads = grad_fn(params, batch, use_rng)
  params = jax.tree_util.tree_map(lambda a, b: a - step_size * b, params, grads)

# Evaluation.
print(fedjax.evaluate_model(model, params, batches))
# Example output:
# {'loss': 2.3, 'accuracy': 0.2}
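The same train-then-evaluate flow can be traced without JAX. The sketch below uses a hypothetical ToyModel NamedTuple that carries the five Model fields, a one-parameter linear model, and a central finite-difference gradient standing in for jax.grad. None of these names are fedjax API; they only mirror the call pattern of init, apply_for_train, train_loss, and apply_for_eval.

```python
from typing import Any, Callable, Mapping, NamedTuple


class ToyModel(NamedTuple):
  """Hypothetical stand-in with the same five fields as fedjax.Model."""
  init: Callable[[Any], Any]
  apply_for_train: Callable[[Any, Mapping[str, list], Any], list]
  apply_for_eval: Callable[[Any, Mapping[str, list]], list]
  train_loss: Callable[[Mapping[str, list], list], list]
  eval_metrics: Mapping[str, Any]


# 1-D linear regression y = w * x; the params "PyTree" is a single float w.
model = ToyModel(
    init=lambda rng: 0.0,
    apply_for_train=lambda params, batch, rng: [params * x for x in batch['x']],
    apply_for_eval=lambda params, batch: [params * x for x in batch['x']],
    train_loss=lambda batch, preds: [(p - y) ** 2 for p, y in zip(preds, batch['y'])],
    eval_metrics={},
)


def loss(params, batch, rng):
  preds = model.apply_for_train(params, batch, rng)
  return sum(model.train_loss(batch, preds))


def grad_fn(params, batch, rng, eps=1e-6):
  # Central finite difference in place of jax.grad; exact up to rounding for a quadratic loss.
  return (loss(params + eps, batch, rng) - loss(params - eps, batch, rng)) / (2 * eps)


batches = [{'x': [1.0, 2.0], 'y': [3.0, 6.0]}, {'x': [3.0], 'y': [9.0]}]
step_size = 0.02
params = model.init(None)  # The toy init ignores the seed.
for _ in range(50):
  for batch in batches:
    params = params - step_size * grad_fn(params, batch, None)

print(round(params, 4))  # close to the true slope 3.0
print(model.apply_for_eval(params, {'x': [4.0]}))  # close to [12.0]
```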
The following is an example using Model compositionally as a building block to implement model interpolation:
import fedjax
import jax

def interpolate(model_1, model_2, init_weight):

  @jax.jit
  def init(rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1 = model_1.init(rng_1)
    params_2 = model_2.init(rng_2)
    return params_1, params_2, init_weight

  @jax.jit
  def apply_for_train(params, batch, rng):
    rng_1, rng_2 = jax.random.split(rng)
    params_1, params_2, weight = params
    return (model_1.apply_for_train(params_1, batch, rng_1) * weight +
            model_2.apply_for_train(params_2, batch, rng_2) * (1 - weight))

  @jax.jit
  def apply_for_eval(params, batch):
    params_1, params_2, weight = params
    return (model_1.apply_for_eval(params_1, batch) * weight +
            model_2.apply_for_eval(params_2, batch) * (1 - weight))

  return fedjax.Model(init, apply_for_train, apply_for_eval,
                      model_1.train_loss, model_1.eval_metrics)

model = interpolate(model_1, model_2, init_weight=0.5)
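The same composition can be checked without JAX using toy linear models. ToyModel and linear_model below are hypothetical, not fedjax API. The sketch shows that each sub-model must receive its own parameters, and that because the weight is stored in the parameter tuple, a training loop would update it like any other parameter.

```python
from typing import Any, Callable, Mapping, NamedTuple


class ToyModel(NamedTuple):
  """Hypothetical stand-in with the same five fields as fedjax.Model."""
  init: Callable
  apply_for_train: Callable
  apply_for_eval: Callable
  train_loss: Callable
  eval_metrics: Mapping[str, Any]


def linear_model(scale):
  # Hypothetical toy model: predictions are params * x; params start at `scale`.
  return ToyModel(
      init=lambda rng: scale,
      apply_for_train=lambda params, batch, rng: [params * x for x in batch['x']],
      apply_for_eval=lambda params, batch: [params * x for x in batch['x']],
      train_loss=lambda batch, preds: [(p - y) ** 2 for p, y in zip(preds, batch['y'])],
      eval_metrics={},
  )


def interpolate(model_1, model_2, init_weight):
  def init(rng):
    # The toy models ignore randomness, so the seed is not split here.
    return model_1.init(rng), model_2.init(rng), init_weight

  def apply_for_train(params, batch, rng):
    params_1, params_2, weight = params
    preds_1 = model_1.apply_for_train(params_1, batch, rng)
    preds_2 = model_2.apply_for_train(params_2, batch, rng)
    return [weight * a + (1 - weight) * b for a, b in zip(preds_1, preds_2)]

  def apply_for_eval(params, batch):
    params_1, params_2, weight = params
    preds_1 = model_1.apply_for_eval(params_1, batch)
    preds_2 = model_2.apply_for_eval(params_2, batch)
    return [weight * a + (1 - weight) * b for a, b in zip(preds_1, preds_2)]

  return ToyModel(init, apply_for_train, apply_for_eval,
                  model_1.train_loss, model_1.eval_metrics)


model = interpolate(linear_model(1.0), linear_model(3.0), init_weight=0.5)
params = model.init(None)
print(model.apply_for_eval(params, {'x': [1.0, 2.0]}))  # [2.0, 4.0]
```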
- init
Initialization function that takes a seed PRNGKey and returns a PyTree of initialized parameters (i.e. model weights). These parameters will be passed as input into apply_for_train() and apply_for_eval(). Any trainable weights for a model that are modified in the training loop should be contained inside of these parameters.
- Type:
Callable[[jax.Array], Any]
- apply_for_train
Function that takes the parameters PyTree, a batch of examples, and a PRNGKey as inputs and outputs the model predictions for training, which are then passed into train_loss(). The PRNGKey supports stochastic training-time behavior such as dropout.
- apply_for_eval
Function that usually takes the parameters PyTree and a batch of examples as inputs and outputs the model predictions for evaluation, which are then passed to eval_metrics. This is defined separately from apply_for_train() to avoid having to specify inputs like a PRNGKey that are not used in evaluation.
- train_loss
Loss function for training that takes a batch of examples and the model output from apply_for_train() as input and outputs the per example loss. This will typically be called inside a jax.grad() wrapped function to compute gradients.
- eval_metrics
Ordered mapping of evaluation metric names to Metric. These Metrics are defined for single examples and will be used in evaluate_model().
- Type:
Mapping[str, fedjax.core.metrics.Metric]
- fedjax.create_model_from_haiku(transformed_forward_pass, sample_batch, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None)
Creates Model after applying defaults and haiku specific preprocessing.
- Parameters:
transformed_forward_pass (Transformed) – Transformed forward pass from hk.transform().
sample_batch (Mapping[str, Array]) – Example input used to determine model parameter shapes.
train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.
eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().
train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.
eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.
- Return type:
Model
- fedjax.create_model_from_stax(stax_init, stax_apply, sample_shape, train_loss, eval_metrics=None, train_kwargs=None, eval_kwargs=None, input_key='x')
Creates Model after applying defaults and stax specific preprocessing.
- Parameters:
stax_init (Callable[..., Any]) – Initialization function returned from stax.serial().
stax_apply (Callable[..., Array]) – Model forward pass function returned from stax.serial().
sample_shape (Tuple[int, ...]) – The expected shape of the input to the model.
train_loss (Callable[[Mapping[str, Array], Array], Array]) – Loss function for training that outputs per example loss.
eval_metrics (Optional[Mapping[str, Metric]]) – Mapping of evaluation metric names to Metric. These metrics are defined for single examples and will be consumed in evaluate_model().
train_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for training.
eval_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments passed to model for evaluation.
input_key (str) – Key name for the input in the batch mapping.
- Return type:
Model
- fedjax.evaluate_model(model, params, batches)
Evaluates model for multiple batches and returns final results.
This is the recommended way to compute evaluation metrics for a given model.
- Parameters:
model (Model) – Model container.
params (Any) – PyTree of model parameters to be evaluated.
batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.EXAMPLE_MASK_KEY (see ClientDataset.padded_batch()).
- Return type:
Dict[str, Array]
- Returns:
A dictionary of evaluation Metric results.
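The role of the optional mask feature can be illustrated with a pure-Python sketch that aggregates a single metric, accuracy, across batches. evaluate_accuracy, predict, and the 'mask' key name are hypothetical. The sketch assumes the metric is accumulated per example over all batches, with masked-out padding examples excluded from both the numerator and the denominator. It does not reproduce fedjax's Metric machinery.

```python
# Hypothetical key name; in fedjax the mask feature is keyed by
# client_datasets.EXAMPLE_MASK_KEY.
MASK_KEY = 'mask'


def evaluate_accuracy(predict, batches):
  """Example-weighted accuracy over all batches, skipping masked-out padding."""
  correct = 0
  count = 0
  for batch in batches:
    preds = predict(batch)
    # Unpadded batches have no mask feature: every example counts.
    mask = batch.get(MASK_KEY, [True] * len(batch['y']))
    for pred, label, keep in zip(preds, batch['y'], mask):
      if keep:
        count += 1
        correct += int(pred == label)
  return {'accuracy': correct / count if count else float('nan')}


def predict(batch):
  return [1 if x > 0 else 0 for x in batch['x']]


batches = [
    {'x': [1.0, -2.0, 3.0], 'y': [1, 0, 0]},
    # The last batch holds one real example padded to three; padding is masked out.
    {'x': [-1.0, 0.0, 0.0], 'y': [0, 1, 1], MASK_KEY: [True, False, False]},
]
print(evaluate_accuracy(predict, batches))  # {'accuracy': 0.75}
```

Counting the two padding rows as real examples would drop the accuracy from 0.75 to 0.5, which is why padded batches should carry the mask.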
- fedjax.model_grad(model, regularizer=None)
A standard gradient function derived from a model and an optional regularizer.
The scalar loss function being differentiated is simply:
mean(model’s per-example loss) + regularizer term
The returned gradient function supports both unpadded batches and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.
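The scalar objective described above can be written out for a single batch. This pure-Python sketch assumes the mean is taken over unmasked examples only, uses a hypothetical 'mask' key in place of client_datasets.EXAMPLE_MASK_KEY, and drops the PRNGKey argument for brevity. model_grad would differentiate this quantity with respect to params; the sketch only evaluates it.

```python
from typing import Any, Callable, List, Mapping, Optional


def scalar_objective(
    params: Any,
    batch: Mapping[str, list],
    per_example_loss: Callable[[Any, Mapping[str, list]], List[float]],
    regularizer: Optional[Callable[[Any], float]] = None,
    mask_key: str = 'mask',  # Hypothetical key name.
) -> float:
  """mean(per-example loss over unmasked examples) + regularizer(params)."""
  losses = per_example_loss(params, batch)
  # Unpadded batches carry no mask feature, so every example is kept.
  mask = batch.get(mask_key, [True] * len(losses))
  kept = [loss for loss, keep in zip(losses, mask) if keep]
  value = sum(kept) / len(kept)
  if regularizer is not None:
    value += regularizer(params)
  return value


def squared_error(w, batch):
  return [(w * x - y) ** 2 for x, y in zip(batch['x'], batch['y'])]


def l2(w):
  return 0.25 * w ** 2


# Two real examples plus one padding row that the mask excludes.
batch = {'x': [1.0, 2.0, 0.0], 'y': [1.0, 1.0, 5.0], 'mask': [True, True, False]}
print(scalar_objective(2.0, batch, squared_error, l2))  # 6.0
```

Here the two kept losses are 1.0 and 9.0, so the mean is 5.0, and the regularizer adds 0.25 * 2.0 ** 2 = 1.0.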
- fedjax.model_per_example_loss(model)
Convenience function for constructing a per-example loss function from a model.
- fedjax.evaluate_average_loss(params, batches, rng, per_example_loss, regularizer=None)
Evaluates the average per example loss over multiple batches.
- Parameters:
params (Any) – PyTree of model parameters to be evaluated.
batches (Iterable[Mapping[str, Array]]) – Multiple batches to compute and aggregate evaluation metrics over. Each batch can optionally contain a feature keyed by client_datasets.EXAMPLE_MASK_KEY (see ClientDataset.padded_batch()).
rng (Array) – Initial PRNGKey for making per_example_loss calls.
per_example_loss (Callable[[Any, Mapping[str, Array], Array], Array]) – Per example loss function.
regularizer (Optional[Callable[[Any], Array]]) – Optional regularizer function.
- Returns:
The average per example loss, plus the regularizer term when specified.
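A pure-Python sketch of this computation, assuming the average is taken over all unmasked examples across all batches (example-weighted) and that the regularizer is added once at the end. average_loss, precomputed_loss, and the 'mask' key are hypothetical, and the PRNGKey threading is omitted.

```python
from typing import Any, Callable, Iterable, List, Mapping, Optional


def average_loss(
    params: Any,
    batches: Iterable[Mapping[str, list]],
    per_example_loss: Callable[[Any, Mapping[str, list]], List[float]],
    regularizer: Optional[Callable[[Any], float]] = None,
    mask_key: str = 'mask',  # Hypothetical key name.
) -> float:
  """Average per example loss over all unmasked examples in all batches."""
  total, count = 0.0, 0
  for batch in batches:
    losses = per_example_loss(params, batch)
    mask = batch.get(mask_key, [True] * len(losses))
    for loss, keep in zip(losses, mask):
      if keep:
        total += loss
        count += 1
  value = total / count
  # The regularizer is added once, not once per batch.
  if regularizer is not None:
    value += regularizer(params)
  return value


def precomputed_loss(params, batch):
  # Hypothetical: the batch already stores each example's loss.
  return batch['loss']


batches = [{'loss': [1.0, 1.0, 1.0]}, {'loss': [5.0]}]
print(average_loss(None, batches, precomputed_loss))  # 2.0
```

Under this assumption the result is 2.0, whereas averaging the two batch means would give 3.0; the two agree only when all batches hold the same number of unmasked examples.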
- class fedjax.ModelEvaluator(model)
Evaluates model for each client dataset, using either global params or per client params.
To evaluate a Model on a single dataset, use evaluate_model() instead.
- evaluate_global_params(params, clients)
Evaluates batches from each client using global params.
- class fedjax.AverageLossEvaluator(per_example_loss, regularizer=None)
Evaluates average loss for each client dataset, using either global params or per client params.
The average loss is defined as the average per example loss, plus the regularizer term when specified. To evaluate average loss on a single dataset, use evaluate_average_loss() instead.
- evaluate_global_params(params, clients)
Evaluates batches from each client using global params.
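Conceptually, evaluating with global params means running one single-dataset evaluation (such as evaluate_model() or evaluate_average_loss()) separately on each client's batches. A minimal sketch under the assumption that clients arrive as (client_id, batches) pairs; evaluate_each_client and mean_abs_error are hypothetical names, and the real evaluators may compile and schedule the work differently.

```python
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple

ClientId = bytes


def evaluate_each_client(
    evaluate: Callable[[Any, Iterable[dict]], Dict[str, float]],
    params: Any,
    clients: Iterable[Tuple[ClientId, Iterable[dict]]],
) -> Iterator[Tuple[ClientId, Dict[str, float]]]:
  """Applies a single-dataset evaluator to every client with the same global params."""
  for client_id, client_batches in clients:
    yield client_id, evaluate(params, client_batches)


def mean_abs_error(w, batches):
  # Hypothetical single-dataset evaluator for the model y = w * x.
  errors = [abs(w * x - y) for b in batches for x, y in zip(b['x'], b['y'])]
  return {'loss': sum(errors) / len(errors)}


clients = [
    (b'cid0', [{'x': [1.0, 2.0], 'y': [2.0, 4.0]}]),
    (b'cid1', [{'x': [1.0], 'y': [0.0]}]),
]
print(dict(evaluate_each_client(mean_abs_error, 2.0, clients)))
# {b'cid0': {'loss': 0.0}, b'cid1': {'loss': 2.0}}
```

A per client params variant would read the params from each client tuple instead of sharing one value.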
- fedjax.grad(per_example_loss, regularizer=None)
A standard gradient function derived from per-example loss and an optional regularizer.
The scalar loss function being differentiated is simply:
mean(per-example loss) + regularizer term
The returned gradient function supports both unpadded batches and padded batches with the mask feature keyed by client_datasets.EXAMPLE_MASK_KEY.
- Parameters:
per_example_loss – Per example loss function.
regularizer – Optional regularizer function.
- Returns:
A function from (params, batch_example, rng) to gradients.
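The contract in the Returns entry, a factory that yields a function from (params, batch_example, rng) to gradients, can be sketched without JAX. make_grad_fn is a hypothetical name. It differentiates the same mean-plus-regularizer objective, but swaps JAX's automatic differentiation for central finite differences over a flat list of float parameters, and assumes a 'mask' key in place of client_datasets.EXAMPLE_MASK_KEY.

```python
from typing import Any, Callable, List, Mapping, Optional


def make_grad_fn(
    per_example_loss: Callable[[List[float], Mapping[str, list], Any], List[float]],
    regularizer: Optional[Callable[[List[float]], float]] = None,
    eps: float = 1e-6,
    mask_key: str = 'mask',  # Hypothetical key name.
) -> Callable[[List[float], Mapping[str, list], Any], List[float]]:
  """Returns a function from (params, batch, rng) to d(objective)/d(params)."""

  def objective(params, batch, rng):
    losses = per_example_loss(params, batch, rng)
    mask = batch.get(mask_key, [True] * len(losses))
    kept = [loss for loss, keep in zip(losses, mask) if keep]
    value = sum(kept) / len(kept)
    if regularizer is not None:
      value += regularizer(params)
    return value

  def grad_fn(params, batch, rng):
    # Central finite differences, one coordinate at a time, in place of jax.grad.
    grads = []
    for i in range(len(params)):
      up, down = list(params), list(params)
      up[i] += eps
      down[i] -= eps
      grads.append((objective(up, batch, rng) - objective(down, batch, rng)) / (2 * eps))
    return grads

  return grad_fn


def squared_error(params, batch, rng):
  del rng  # Deterministic toy loss.
  w, b = params
  return [(w * x + b - y) ** 2 for x, y in zip(batch['x'], batch['y'])]


grad_fn = make_grad_fn(squared_error, regularizer=lambda p: 0.5 * sum(v * v for v in p))
batch = {'x': [1.0, 2.0, 9.0], 'y': [1.0, 3.0, 0.0], 'mask': [True, True, False]}
print(grad_fn([1.0, 0.0], batch, None))  # approximately [-1.0, -1.0]
```

For this toy loss the analytic gradient at (w, b) = (1, 0) is (-1, -1): the masked-out third example contributes nothing, and the 0.5 * sum(v * v) regularizer adds the params themselves.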