fedjax.datasets

fedjax datasets.

- Federated CIFAR-100
- Federated EMNIST
- Federated Shakespeare
- Federated Stack Overflow
CIFAR-100

Federated CIFAR-100.

- fedjax.datasets.cifar100.load_data(mode='sqlite', cache_dir=None)[source]
Loads partially preprocessed CIFAR-100 splits.
Features:
x: [N, 32, 32, 3] uint8 pixels.
y: [N] int32 labels in the range [0, 100).
Additional preprocessing (e.g. centering and normalizing) depends on whether a split is used for training or eval. For example:

```python
import functools
from fedjax.datasets import cifar100

# Load partially preprocessed splits.
train, test = cifar100.load_data()

# Preprocessing for training.
train_for_train = train.preprocess_batch(
    functools.partial(preprocess_batch, is_train=True))
# Preprocessing for eval.
train_for_eval = train.preprocess_batch(
    functools.partial(preprocess_batch, is_train=False))
test = test.preprocess_batch(
    functools.partial(preprocess_batch, is_train=False))
```
Features after this preprocessing:
x: [N, 32, 32, 3] float32 preprocessed pixels.
y: [N] int32 labels in the range [0, 100).
Alternatively, you can apply the same preprocessing as TensorFlow Federated, following tff.simulation.baselines.cifar100.create_image_classification_task. For example:

```python
from fedjax.datasets import cifar100

train, test = cifar100.load_data()
train = train.preprocess_batch(preprocess_batch_tff)
test = test.preprocess_batch(preprocess_batch_tff)
```
Features after this preprocessing:
x: [N, 24, 24, 3] float32 preprocessed pixels.
y: [N] int32 labels in the range [0, 100).
Note: preprocess_batch and preprocess_batch_tff are just convenience wrappers around preprocess_image() and preprocess_image_tff(), respectively, for use with fedjax.FederatedData.preprocess_batch().

- Parameters:
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  Tuple[FederatedData, FederatedData]
- Returns:
  A (train, test) tuple of federated data.
- fedjax.datasets.cifar100.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a CIFAR-100 split.
Features:
image: [N, 32, 32, 3] uint8 pixels.
coarse_label: [N] int64 coarse labels in the range [0, 20).
label: [N] int64 labels in the range [0, 100).
- Parameters:
  - split (str) – Name of the split. One of SPLITS.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  FederatedData
- Returns:
  FederatedData.
- fedjax.datasets.cifar100.preprocess_image(image, is_train)[source]
Augments and preprocesses CIFAR-100 images by cropping, flipping, and normalizing.
Preprocessing procedure and values taken from pytorch-cifar.
- Parameters:
  - image (ndarray) – [N, 32, 32, 3] uint8 pixels.
  - is_train (bool) – Whether we are preprocessing for training or eval.
- Return type:
  ndarray
- Returns:
  Processed [N, 32, 32, 3] float32 pixels.
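As a rough illustration of the procedure described above (not the fedjax implementation: the normalization constants shown are the commonly cited CIFAR-100 channel statistics, and the pad-4/crop/flip parameters are assumptions based on pytorch-cifar-style augmentation), the training path pads, randomly crops, randomly flips, and normalizes:

```python
import numpy as np

# Commonly cited CIFAR-100 channel statistics; the exact constants used by
# fedjax are defined in its source and may differ.
MEAN = np.array([0.5071, 0.4865, 0.4409], dtype=np.float32)
STD = np.array([0.2673, 0.2564, 0.2762], dtype=np.float32)

def preprocess_image_sketch(image, is_train, rng=None):
    """Pads/crops/flips (train only), then normalizes [N, 32, 32, 3] pixels."""
    if rng is None:
        rng = np.random.default_rng(0)
    x = image.astype(np.float32) / 255.0
    if is_train:
        # Pad 4 pixels on each side, then take a random 32x32 crop per image.
        x = np.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)], mode='constant')
        out = np.empty((x.shape[0], 32, 32, 3), dtype=np.float32)
        for i in range(x.shape[0]):
            top, left = rng.integers(0, 9, size=2)  # offsets in [0, 8]
            crop = x[i, top:top + 32, left:left + 32]
            if rng.random() < 0.5:  # random horizontal flip
                crop = crop[:, ::-1]
            out[i] = crop
        x = out
    # Per-channel normalization.
    return (x - MEAN) / STD
```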
EMNIST
Federated EMNIST.
- fedjax.datasets.emnist.domain_id(client_id)[source]
Returns domain id for client id.
Domain ids are based on the NIST data source, where examples were collected from two sources: Bethesda high school (HIGH_SCHOOL) and Census Bureau in Suitland (CENSUS). For more details, see the NIST documentation.
- Parameters:
  client_id (bytes) – Client id of the format [16-byte hex hash]:f[4-digit integer]_[2-digit integer] or f[4-digit integer]_[2-digit integer].
- Return type:
  int
- Returns:
  Domain id that is 0 (HIGH_SCHOOL) or 1 (CENSUS).
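The two documented client id formats can be recognized with a simple regular expression. This is only an illustration of the id layout (the helper name is hypothetical); the actual HIGH_SCHOOL/CENSUS decision in fedjax depends on which NIST source a writer number falls under, which is not reproduced here:

```python
import re

# Matches both documented formats:
#   [16-byte hex hash]:f[4-digit integer]_[2-digit integer]
#   f[4-digit integer]_[2-digit integer]
CLIENT_ID_RE = re.compile(rb'^(?:[0-9a-f]{32}:)?f(\d{4})_(\d{2})$')

def writer_number(client_id):
    """Extracts the 4-digit writer number from a client id (hypothetical helper)."""
    match = CLIENT_ID_RE.match(client_id)
    if match is None:
        raise ValueError('unrecognized client id: %r' % (client_id,))
    return int(match.group(1))
```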
- fedjax.datasets.emnist.load_data(only_digits=False, mode='sqlite', cache_dir=None)[source]
Loads processed EMNIST train and test splits.
Features:
x: [N, 28, 28, 1] float32 flipped image pixels.
y: [N] int32 classification label.
domain_id: [N] int32 domain id (see domain_id()).
- Parameters:
  - only_digits (bool) – Whether to only load the digits data.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  Tuple[FederatedData, FederatedData]
- Returns:
  Train and test splits as FederatedData.
- fedjax.datasets.emnist.load_split(split, only_digits=False, mode='sqlite', cache_dir=None)[source]
Loads an unprocessed federated EMNIST split.
Features:
pixels: [N, 28, 28] float32 image pixels.
label: [N] int32 classification label.
- Parameters:
  - split (str) – Name of the split. One of SPLITS.
  - only_digits (bool) – Whether to only load the digits data.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  FederatedData
- Returns:
  FederatedData.
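The relationship between the unprocessed pixels feature here and the processed x feature in load_data can be sketched as follows. This is an assumption about what "flipped" means (inverting intensities and adding a trailing channel axis), not the fedjax source:

```python
import numpy as np

def to_model_input(pixels):
    """[N, 28, 28] float32 pixels -> [N, 28, 28, 1] float32 flipped pixels.

    'Flipped' is assumed to mean intensity inversion (1.0 - p), plus a
    trailing channel axis for convolutional models.
    """
    return (1.0 - pixels)[..., np.newaxis].astype(np.float32)
```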
Shakespeare
Federated Shakespeare.
- fedjax.datasets.shakespeare.load_data(sequence_length=80, mode='sqlite', cache_dir=None)[source]
Loads preprocessed Shakespeare splits.
Preprocessing is done using fedjax.FederatedData.preprocess_client() and preprocess_client().
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)
- Parameters:
  - sequence_length (int) – The fixed sequence length after preprocessing.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  Tuple[FederatedData, FederatedData, FederatedData]
- Returns:
  A (train, held_out, test) tuple of federated data.
- fedjax.datasets.shakespeare.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a Shakespeare split.
Features:
snippets: [N] bytes array of snippet text.
- Parameters:
  - split (str) – Name of the split. One of SPLITS.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  FederatedData
- Returns:
  FederatedData.
- fedjax.datasets.shakespeare.preprocess_client(client_id, examples, sequence_length)[source]
Turns snippets into sequences of integer labels.
Features (M below is possibly different from N in load_split):
x: [M, sequence_length] int32 input labels, in the range of [0, shakespeare.VOCAB_SIZE)
y: [M, sequence_length] int32 output labels, in the range of [0, shakespeare.VOCAB_SIZE)
All snippets in a client dataset are first joined into a single sequence (with BOS/EOS added), and then split into pairs of sequence_length chunks for language model training. For example, with sequence_length=3, [b'ABCD', b'E'] becomes:

Input sequences:  [[BOS, A, B], [C, D, EOS], [BOS, E, PAD]]
Output sequences: [[A, B, C], [D, EOS, BOS], [E, EOS, PAD]]

Note: This is not equivalent to the TensorFlow Federated text generation tutorial (the processing logic there loses a ~1/sequence_length portion of the tokens).
- Parameters:
  - client_id (bytes) – Not used.
  - examples (Mapping[str, ndarray]) – Unprocessed examples (e.g. from load_split()).
  - sequence_length (int) – The fixed sequence length after preprocessing.
- Return type:
  Mapping[str, ndarray]
- Returns:
  Processed examples.
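The join-and-chunk procedure above can be sketched in plain numpy. The special token ids and the byte-to-label mapping below are illustrative assumptions, not fedjax's actual vocabulary, but the chunking itself reproduces the [b'ABCD', b'E'] example:

```python
import numpy as np

# Illustrative special token ids; the real fedjax vocabulary may differ.
PAD, BOS, EOS = 0, 1, 2
VOCAB_OFFSET = 3  # hypothetical offset for ordinary byte values

def chunk_snippets(snippets, sequence_length):
    """Joins snippets (with BOS/EOS) and splits into input/output chunks."""
    joined = []
    for snippet in snippets:
        joined.append(BOS)
        joined.extend(VOCAB_OFFSET + b for b in snippet)
        joined.append(EOS)
    # Outputs predict the next token, so shift by one before chunking.
    inputs, outputs = joined[:-1], joined[1:]
    n_chunks = -(-len(inputs) // sequence_length)  # ceiling division
    pad_to = n_chunks * sequence_length
    inputs = inputs + [PAD] * (pad_to - len(inputs))
    outputs = outputs + [PAD] * (pad_to - len(outputs))
    x = np.array(inputs, dtype=np.int32).reshape(n_chunks, sequence_length)
    y = np.array(outputs, dtype=np.int32).reshape(n_chunks, sequence_length)
    return {'x': x, 'y': y}
```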
Stack Overflow
Federated Stack Overflow.
- fedjax.datasets.stackoverflow.load_data(mode='sqlite', cache_dir=None)[source]
Loads partially preprocessed stackoverflow splits.
Features:
domain_id: [N] int32 domain id derived from type (question = 0; answer = 1).
tokens: [N] bytes array. Space separated list of tokens.
To convert tokens into padded/truncated integer labels, use a StackoverflowTokenizer. For example:

```python
from fedjax.datasets import stackoverflow

# Load partially preprocessed splits.
train, held_out, test = stackoverflow.load_data()

# Apply tokenizer during batching.
tokenizer = stackoverflow.StackoverflowTokenizer()
train_max_length, eval_max_length = 20, 30
train_for_train = train.preprocess_batch(
    tokenizer.as_preprocess_batch(train_max_length))
train_for_eval = train.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
held_out = held_out.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
test = test.preprocess_batch(
    tokenizer.as_preprocess_batch(eval_max_length))
```
Features after tokenization:
domain_id: Same as before.
x: [N, max_length] int32 array of padded/truncated input labels.
y: [N, max_length] int32 array of padded/truncated output labels.
- Parameters:
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  Tuple[FederatedData, FederatedData, FederatedData]
- Returns:
  A (train, held_out, test) tuple of federated data.
- fedjax.datasets.stackoverflow.load_split(split, mode='sqlite', cache_dir=None)[source]
Loads a stackoverflow split.
All bytes arrays are stored with dtype=np.object.
Features:
creation_date: [N] bytes array. Textual timestamp, e.g. b’2018-02-28 19:06:18.34 UTC’.
title: [N] bytes array. The title of a post.
score: [N] int64 array. The score of a post.
tags: [N] bytes array. ‘|’ separated list of tags, e.g. b’mysql|join’.
tokens: [N] bytes array. Space separated list of tokens.
type: [N] bytes array. Either b’question’ or b’answer’.
- Parameters:
  - split (str) – Name of the split. One of SPLITS.
  - mode (str) – 'sqlite'.
  - cache_dir (Optional[str]) – Directory to cache files in 'sqlite' mode.
- Return type:
  FederatedData
- Returns:
  FederatedData.
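The domain_id derivation mentioned in load_data (question = 0, answer = 1) can be sketched from the type feature here. This helper is an illustration, not the fedjax implementation:

```python
import numpy as np

def derive_domain_id(post_type):
    """[N] bytes array of b'question'/b'answer' -> [N] int32 domain ids."""
    return np.where(np.asarray(post_type) == b'answer', 1, 0).astype(np.int32)
```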
- class fedjax.datasets.stackoverflow.StackoverflowTokenizer(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]
Tokenizer for the tokens feature in stackoverflow.
See load_data() for examples.

- __init__(vocab=None, default_vocab_size=10000, num_oov_buckets=1)[source]
Initializes a tokenizer.
- Parameters:
  - vocab (Optional[List[str]]) – Optional vocabulary. If specified, default_vocab_size is ignored. If None, default_vocab_size is used to load the standard vocabulary. This vocabulary should NOT have special tokens PAD, EOS, BOS, and OOV. The special tokens are added and handled automatically by the tokenizer. The preprocessed examples will have vocabulary size len(vocab) + 3 + num_oov_buckets.
  - default_vocab_size (Optional[int]) – Number of words in the default vocabulary. This is only used when vocab is not specified. The preprocessed examples will have vocabulary size default_vocab_size + 3 + num_oov_buckets, with 3 special labels: 0 (PAD), 1 (BOS), 2 (EOS), and num_oov_buckets OOV labels starting at default_vocab_size + 3.
  - num_oov_buckets (int) – Number of out of vocabulary buckets.
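A minimal sketch of the tokenization contract described above, assuming the label layout 0 (PAD), 1 (BOS), 2 (EOS), in-vocabulary words starting at 3, and OOV buckets immediately after the vocabulary. The function name and OOV hashing scheme are illustrative, not fedjax's:

```python
import numpy as np

def tokenize_batch(tokens_batch, vocab, max_length, num_oov_buckets=1):
    """Maps space-separated byte strings to padded/truncated int32 labels.

    Assumed label layout: 0=PAD, 1=BOS, 2=EOS, in-vocab words start at 3,
    OOV buckets start at len(vocab) + 3.
    """
    word_to_id = {w: i + 3 for i, w in enumerate(vocab)}
    oov_start = len(vocab) + 3
    x = np.zeros((len(tokens_batch), max_length), dtype=np.int32)
    y = np.zeros((len(tokens_batch), max_length), dtype=np.int32)
    for row, line in enumerate(tokens_batch):
        words = line.decode('utf-8').split()
        ids = [word_to_id.get(w, oov_start + hash(w) % num_oov_buckets)
               for w in words]
        inp = ([1] + ids)[:max_length]  # BOS + tokens, truncated
        out = (ids + [2])[:max_length]  # tokens + EOS, truncated
        x[row, :len(inp)] = inp
        y[row, :len(out)] = out
    return {'x': x, 'y': y}
```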