add augmentations

This commit is contained in:
shahules786 2022-10-25 12:43:54 +05:30
parent cdffe5c485
commit d1bafb3dc6
2 changed files with 9 additions and 1 deletions

View File

@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_audiomentations import BandPassFilter, Compose, Shift
os.environ["HYDRA_FULL_ERROR"] = "1" os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0") JOB_ID = os.environ.get("SLURM_JOBID", "0")
@ -25,8 +26,14 @@ def main(config: DictConfig):
) )
parameters = config.hyperparameters parameters = config.hyperparameters
apply_augmentations = Compose(
[
Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
BandPassFilter(p=0.5),
]
)
dataset = instantiate(config.dataset) dataset = instantiate(config.dataset, augmentations=apply_augmentations)
model = instantiate( model = instantiate(
config.model, config.model,
dataset=dataset, dataset=dataset,

View File

@ -211,6 +211,7 @@ class TaskDataset(pl.LightningDataModule):
batch_size=self.batch_size, batch_size=self.batch_size,
num_workers=self.num_workers, num_workers=self.num_workers,
generator=self.generator, generator=self.generator,
collate_fn=self.train_collatefn,
) )
def val_dataloader(self): def val_dataloader(self):