
Conversation

@giovp
Member

@giovp commented Dec 17, 2025

This is an attempt to wrap the distributed dataloader handling directly into the new Loader class. It does two things:

  • handles the torch distributed variables internally and automatically splits chunks across shards
  • implements two alternative modes for handling uneven chunks (both are sketched right after this list):
    • drop_last_indices
    • pad_indices: pads the indices of uneven chunks with observations from the "start" of the dataset, which introduces duplicates over the epoch
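
To make the trade-off concrete, here is a minimal index-level sketch of the two modes. This is not the actual PR code: the function name `shard_indices` and the index-level (rather than chunk-level) granularity are illustrative assumptions.

```python
import numpy as np


def shard_indices(n_obs: int, rank: int, world_size: int, mode: str = "pad_indices") -> np.ndarray:
    """Illustrative only: split a global index range evenly across ranks."""
    indices = np.arange(n_obs)
    if mode == "drop_last_indices":
        # truncate so every rank gets the same count; trailing observations are skipped this epoch
        per_rank = n_obs // world_size
        indices = indices[: per_rank * world_size]
    elif mode == "pad_indices":
        # pad with indices from the start of the dataset; padded observations are seen twice
        per_rank = -(-n_obs // world_size)  # ceiling division
        pad = per_rank * world_size - n_obs
        indices = np.concatenate([indices, indices[:pad]])
    else:
        raise ValueError(f"unknown mode: {mode}")
    return indices[rank::world_size]


# With 10 observations and 3 ranks:
# drop_last_indices -> 3 indices per rank, observation 9 is never seen this epoch
# pad_indices       -> 4 indices per rank, observations 0 and 1 appear twice in the epoch
```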

Not sure if this is the type of ergonomics you are interested in supporting, but I personally find it useful. I'd be interested to hear if you have other approaches for distributed training.

Maybe working towards #56

@codecov

codecov bot commented Dec 17, 2025

Codecov Report

❌ Patch coverage is 88.50575% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 90.07%. Comparing base (7eb124a) to head (214c4d7).

| Files with missing lines    | Patch % | Lines       |
| --------------------------- | ------- | ----------- |
| src/annbatch/loader.py      | 88.46%  | 9 Missing ⚠️ |
| src/annbatch/distributed.py | 88.88%  | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #100      +/-   ##
==========================================
- Coverage   90.90%   90.07%   -0.84%     
==========================================
  Files           5        6       +1     
  Lines         594      675      +81     
==========================================
+ Hits          540      608      +68     
- Misses         54       67      +13     
| Files with missing lines    | Coverage Δ                  |
| --------------------------- | --------------------------- |
| src/annbatch/distributed.py | 88.88% <88.88%> (ø)         |
| src/annbatch/loader.py      | 92.21% <88.46%> (-0.88%) ⬇️ |

... and 1 file with indirect coverage changes


@giovp
Member Author

giovp commented Dec 17, 2025

Also, in case you have time, could you look at the "cache-efficient strategy" in the cellarium dataloader here: https://github.com/cellarium-ai/cellarium-ml/blob/58bc81b1e4ff51ceef51664bd99aed8af229b412/cellarium/ml/data/dadc_dataset.py#L193 and check whether you think it would make sense for the annbatch implementation? My understanding is that it may not, since annbatch uses a different logic to yield the batches.

@felix0097
Collaborator

Hi @giovp,

Thanks for your input! I have a few questions. As far as I understand, you're sharding the input dataset across several loaders, right?

@ilan-gold already implemented a similar strategy to support multi-process data loading for the torch DataLoader here. Does it make sense to somehow try to merge this?

Continuing on this thought, @selmanozleyen is currently working on a sampler API. Maybe the best approach is to merge those efforts somehow? E.g. use Selman's sampler API, then have a TorchDistributedSampler class that does the sharding for the user. This would also make it very easy to extend to other distributed trainers in the future.

Let me know what you think!

@ilan-gold
Collaborator

E.g. use Selman's sampler API, then have a TorchDistributedSampler class that does the sharding for the user. This would also make it very easy to extend to other distributed trainers in the future.

We also spoke about maybe just having a DistributedSampler class that takes a world/rank fetcher function as an argument, so that it could be used with other frameworks. It looks like that is the only dependency on torch at the moment, so it should be easy enough to abstract out.
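
A rough sketch of that idea, purely as an assumption about the shape it could take (the class name, constructor arguments, and the strided chunk assignment are illustrative, not existing annbatch API):

```python
from collections.abc import Callable, Iterator


class DistributedSampler:
    """Illustrative sketch: shard chunk indices using pluggable rank/world-size fetchers."""

    def __init__(
        self,
        n_chunks: int,
        get_rank: Callable[[], int],
        get_world_size: Callable[[], int],
    ) -> None:
        # the fetchers are injected, so nothing here imports torch (or any other framework)
        self.n_chunks = n_chunks
        self.get_rank = get_rank
        self.get_world_size = get_world_size

    def __iter__(self) -> Iterator[int]:
        rank, world_size = self.get_rank(), self.get_world_size()
        # every rank walks the same global chunk order but keeps only its strided slice
        yield from range(rank, self.n_chunks, world_size)


# e.g. with the torch backend:
# import torch.distributed as dist
# sampler = DistributedSampler(n_chunks=128, get_rank=dist.get_rank, get_world_size=dist.get_world_size)
```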

@giovp
Member Author

giovp commented Dec 17, 2025

Thanks both for the comments!

@ilan-gold already implemented a similar strategy to support multi-process data loading for the torch DataLoader here. Does it make sense to somehow try to merge this?

This is at the worker level though, whereas this PR is concerned with rank-level distribution. By merging, do you mean in the sense of refactoring? I'd be happy to take a look, but otherwise they are two different levels of distributed streaming.
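
For reference, the two levels compose; here is a small sketch (not annbatch code) of how a rank-level shard and a DataLoader worker-level shard can be flattened into a single shard id, using only public torch APIs:

```python
import torch.distributed as dist
from torch.utils.data import get_worker_info


def global_shard() -> tuple[int, int]:
    """Illustrative sketch: flatten rank-level and worker-level sharding into one shard id."""
    # rank level: which training process (typically one per GPU) is consuming data
    initialized = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if initialized else 0
    world_size = dist.get_world_size() if initialized else 1
    # worker level: which DataLoader worker process inside that rank is iterating
    info = get_worker_info()  # None when iterating in the main process
    worker_id = info.id if info is not None else 0
    num_workers = info.num_workers if info is not None else 1
    # a flat shard id over world_size * num_workers consumers
    return rank * num_workers + worker_id, world_size * num_workers
```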

Maybe the best approach is to merge those efforts somehow? E.g. use Selman's sampler API, then have a TorchDistributedSampler class that does the sharding for the user.

Mmh, you can't pass a DistributedSampler; in fact, you can't pass any Sampler to iterable datasets, and my understanding is that the Loader class is effectively an iterable dataset (it both inherits from IterableDataset and implements an __iter__ method, not a __getitem__ method). What I understood @selmanozleyen's PR to be doing is implementing the same sampling behaviour in the iterable case, which I think would be extremely useful btw, possibly #1 on my feature wishlist.
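
For context, a minimal reproduction of the torch constraint being referenced (the tiny dataset is made up; the error is raised by DataLoader itself):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class TinyIterable(IterableDataset):
    def __iter__(self):
        yield from (torch.tensor(i) for i in range(4))


# torch's DataLoader refuses any sampler when the dataset is iterable-style
try:
    DataLoader(TinyIterable(), sampler=SequentialSampler(range(4)))
except ValueError as err:
    print(err)  # DataLoader with IterableDataset: expected unspecified sampler option, ...
```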

We also spoke about maybe just having a DistributedSampler class that takes a world/rank fetcher function as an argument, so that it could be used with other frameworks. It looks like that is the only dependency on torch at the moment, so it should be easy enough to abstract out.

Other frameworks would definitely be cool, but then you get into the rabbit hole of which types of parallelization you want to support, and in which framework. Definitely interesting, but also possibly a fair amount of work.

@ilan-gold
Collaborator

By merging, do you mean in the sense of refactoring? I'd be happy to take a look, but otherwise they are two different levels of distributed streaming.

Yes, I think Felix's point is that "mask out some portion of the dataset" is a common op, so there's no sense duplicating it: let's try to share a utility function or something.

Mmh, you can't pass a DistributedSampler; in fact, you can't pass any Sampler to iterable datasets, and my understanding is that the Loader class is effectively an iterable dataset (it both inherits from IterableDataset and implements an __iter__ method, not a __getitem__ method). What I understood @selmanozleyen's PR to be doing is implementing the same sampling behaviour in the iterable case, which I think would be extremely useful btw, possibly #1 on my feature wishlist.

Right, that is what Selman is doing. So here you would reimplement what you have as a DistributedSampler; it might make sense to look at his PR and work off of it / comment on it.

Other frameworks would definitely be cool, but then you get into the rabbit hole of which types of parallelization you want to support, and in which framework. Definitely interesting, but also possibly a fair amount of work.

Right, that's why I'd like to understand the space for this stuff a bit more. What is the minimal set of requirements we can build (either in a base sampler class or in this PR) to enable different kinds of distributed sampling? If this PR only implements one form, we should make that clear. Surely there has to be some sort of clear-ish documentation on this stuff somewhere?

