Summary
Normalization and augmentations are defined in the on_after_batch_transfer() method of the datamodules so that they run on the GPU, as recommended by Lightning. However, a downside of this is that you always have to pass the datamodule into .fit and .test. For the latter in particular, it can be convenient to test on separate dataloaders; however, those are then just "raw" dataloaders without normalization etc. applied. It took me a minute to find that this was the reason for the funky test results. Currently, I write a custom collate_fn and set it on the dataloader that I get from a datamodule, but it would be nice if this could be handled more easily. Open to thoughts about this, or suggestions for an easier way to handle it than what I am doing at the moment.
Rationale
Sometimes I would like to test a model on different datasets, and if a torchgeo datamodule is available, it is convenient to just retrieve a configured dataloader from an implemented datamodule.
Implementation
Maybe a flag could be added that returns a dataloader whose collate function applies the on_after_batch_transfer augmentations.
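As a rough sketch, such a flag could be backed by a small helper that wraps any dataloader obtained from a datamodule (the helper name and signature here are hypothetical, not part of torchgeo):

```python
# Hypothetical helper, not part of torchgeo: reuse the datamodule's
# on_after_batch_transfer as a collate-time transform instead.
def with_datamodule_transforms(datamodule, dataloader, dataloader_idx=0):
    base_collate = dataloader.collate_fn  # keep the default stacking behaviour

    def collate(batch):
        stacked = base_collate(batch)
        # apply the datamodule's normalization/augmentations to the batch
        return datamodule.on_after_batch_transfer(stacked, dataloader_idx)

    dataloader.collate_fn = collate
    return dataloader
```

Calling e.g. `with_datamodule_transforms(datamodule, datamodule.val_dataloader())` would then yield already-normalized batches; the obvious trade-off is that the transforms now run in the CPU dataloader workers instead of on the GPU.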
Alternatives
Currently I am doing something like this:
import torch
from torchgeo.datamodules import ETCI2021DataModule

datamodule = ETCI2021DataModule(root=".", download=True, num_workers=4, batch_size=32)
datamodule.setup("fit")

def collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Collate fn that also applies the datamodule's augmentations."""
    inputs = torch.stack([item["image"] for item in batch])
    targets = torch.stack([item["mask"] for item in batch])
    # the hook expects the batch and a dataloader index
    return datamodule.on_after_batch_transfer({"image": inputs, "mask": targets}, 0)

val_dataloader = datamodule.val_dataloader()
val_dataloader.collate_fn = collate
I can understand why you would want to be able to use a dataset if a data module doesn't exist, but why would you want to use a dataset if a data module does exist?
In order to do trainer.validate(model, dataloaders=datamodule.val_dataloader()) without having to implement my own normalization scheme as a collate_fn for every dataloader from a datamodule I want to use. So, for example, say I train one model and want to validate it on a bunch of datasets; then I could pass multiple dataloaders from different datasets or datamodules to trainer.validate().
If you pass a datamodule, it will only select the predefined validation loader and validate on that, but maybe I would like to validate on the train set and the validation set, for example when taking a pre-trained model and checking performance without training. It might also be relevant if you try something like cross-validation, where you split your train/val sets. In my case, I am trying conformal prediction, where you need to take a subset of the validation set to create a separate calibration set and use the model with that, so you need to control which split the dataloader applies validation to.
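For the conformal-prediction case, a minimal sketch of carving a calibration subset out of the validation indices (the helper name is mine, not an existing API):

```python
import random

def calibration_split(num_samples, calib_fraction=0.2, seed=0):
    """Split dataset indices into a validation part and a calibration part,
    e.g. for conformal prediction. Illustrative helper, not a torchgeo API."""
    rng = random.Random(seed)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    n_calib = int(num_samples * calib_fraction)
    return indices[n_calib:], indices[:n_calib]
```

The two index lists could then be wrapped with torch.utils.data.Subset to build separate dataloaders, each of which would still need the augmentation-aware collate_fn to get normalized batches.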