arrayloaders.io.DaskDataset

class arrayloaders.io.DaskDataset(adata, label_column, n_chunks=8, shuffle=True, dask_scheduler='threads', n_workers=None)

Bases: IterableDataset

Dask-based IterableDataset that yields samples from an AnnData object in chunks.

Parameters:
  • adata (AnnData) – The AnnData object to yield samples from.

  • label_column (str) – The name of the column in adata.obs that contains the labels.

  • n_chunks (int, default: 8) – Number of chunks of the underlying dask.array to load at a time. Loading more chunks at once improves throughput and shuffling randomness (see the sketch after this list) but increases memory usage. Defaults to 8.

  • shuffle (bool, default: True) – Whether to yield samples in a random order. Defaults to True.

  • dask_scheduler (Literal['synchronous', 'threads'], default: 'threads') – The Dask scheduler used to compute chunks: “synchronous” for single-threaded execution, “threads” for multithreaded execution. Defaults to “threads”.

  • n_workers (int | None, default: None) – Number of Dask workers to use. If None, the number of workers is determined by Dask.
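
Conceptually, n_chunks trades memory for shuffling quality: with shuffle=True, a window of chunks is materialized in memory and rows are drawn from it in random order, so a larger window approximates a global shuffle more closely. A minimal sketch of that buffered chunk-shuffle pattern (the chunk_shuffle helper below is a hypothetical illustration, not the actual DaskDataset internals):

>>> import dask.array as da
>>> import numpy as np
>>> def chunk_shuffle(arr, n_chunks=8, seed=0):
...     # Hypothetical re-implementation of the pattern, for illustration only.
...     rng = np.random.default_rng(seed)
...     blocks = [arr.blocks[i] for i in range(arr.numblocks[0])]
...     rng.shuffle(blocks)  # randomize the order in which chunks are visited
...     for i in range(0, len(blocks), n_chunks):
...         # Materialize a window of n_chunks chunks, then shuffle its rows.
...         buffer = da.concatenate(blocks[i : i + n_chunks]).compute()
...         rng.shuffle(buffer)  # in-place row shuffle within the window
...         yield from buffer
>>> x = da.random.random((10_000, 64), chunks=(1_000, 64))
>>> row = next(chunk_shuffle(x, n_chunks=8))
>>> row.shape
(64,)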

Examples

>>> from arrayloaders.io.dask_loader import DaskDataset, read_lazy_store
>>> from torch.utils.data import DataLoader
>>> label_column = "y"
>>> adata = read_lazy_store("path/to/zarr/store", obs_columns=[label_column])
>>> dataset = DaskDataset(adata, label_column=label_column, n_chunks=8, shuffle=True)
>>> dataloader = DataLoader(dataset, batch_size=2048, num_workers=4, drop_last=True)
>>> for batch in dataloader:
...     x, y = batch  # x: feature matrix for the batch, y: labels from label_column
...     # Process the batch
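
The dataloader plugs into a standard training loop. A hedged sketch, assuming x holds float features with adata.n_vars columns and y holds integer class labels (the 10-class linear model is a placeholder, not part of arrayloaders):

>>> import torch
>>> model = torch.nn.Linear(adata.n_vars, 10)  # placeholder: 10 classes assumed
>>> optimizer = torch.optim.Adam(model.parameters())
>>> loss_fn = torch.nn.CrossEntropyLoss()
>>> for x, y in dataloader:
...     optimizer.zero_grad()
...     loss = loss_fn(model(x.float()), y.long())
...     loss.backward()
...     optimizer.step()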