BatchSampler

class pinnx.utils.sampler.BatchSampler(num_samples, shuffle=True)

Samples a mini-batch of indices.

The indices are repeated indefinitely. This has the same effect as:

import tensorflow as tf

# num_samples and batch_size are assumed to be defined by the caller.
indices = tf.data.Dataset.range(num_samples)
indices = indices.repeat().shuffle(num_samples).batch(batch_size)
iterator = iter(indices)
batch_indices = iterator.get_next()

However, tf.data.Dataset.__iter__() is supported only inside a tf.function or when eager execution is enabled. tf.data.Dataset.make_one_shot_iterator() works in graph mode, but is too slow.

This class is deliberately not implemented as a Python iterator, so that get_next() can take a different batch size on each call.

Parameters:
  • num_samples (int) – The number of samples.

  • shuffle (bool) – Set to True to have the indices reshuffled at every epoch.

get_next(batch_size)

Returns the indices of the next batch.

Parameters:
  • batch_size (int) – The number of elements to combine in a single batch.
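
The behaviour described above (indices that repeat indefinitely, optional reshuffling at every epoch, and a batch size chosen per call) can be illustrated with a minimal pure-Python sketch. This is not the pinnx implementation: the class name SketchBatchSampler, the seed argument, the choice to return plain lists, and the way a batch is stitched together across an epoch boundary are all assumptions made for illustration.

```python
import random


class SketchBatchSampler:
    """Hypothetical stand-in that mimics the documented behaviour."""

    def __init__(self, num_samples, shuffle=True, seed=None):
        self.num_samples = num_samples
        self.shuffle = shuffle
        # `seed` is an assumption for reproducibility; it is not part of
        # the documented BatchSampler signature.
        self._rng = random.Random(seed)
        self._indices = list(range(num_samples))
        self._pos = 0  # position of the next unread index in this epoch
        if shuffle:
            self._rng.shuffle(self._indices)

    def get_next(self, batch_size):
        """Return the next `batch_size` indices, wrapping across epochs."""
        if batch_size > self.num_samples:
            raise ValueError("batch_size must not exceed num_samples")
        end = self._pos + batch_size
        if end <= self.num_samples:
            # The whole batch fits inside the current epoch.
            batch = self._indices[self._pos:end]
            self._pos = end
            return batch
        # Epoch boundary: keep the tail of this epoch (the slice is a copy),
        # reshuffle if requested, then take the remainder from the new epoch.
        batch = self._indices[self._pos:]
        if self.shuffle:
            self._rng.shuffle(self._indices)
        self._pos = batch_size - len(batch)
        return batch + self._indices[:self._pos]


# Usage: the batch size may change from one call to the next.
sampler = SketchBatchSampler(num_samples=10, shuffle=False)
first = sampler.get_next(4)   # [0, 1, 2, 3]
second = sampler.get_next(3)  # [4, 5, 6]
third = sampler.get_next(5)   # [7, 8, 9, 0, 1], wraps into the next epoch
```

Because get_next() is an ordinary method rather than __next__(), the caller passes the batch size on every call, which is what a Python iterator cannot express directly.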