add torch-augmentations

This commit is contained in:
shahules786 2022-10-24 21:50:30 +05:30
parent 5dc5fd8f90
commit 542ab23d8a
1 changed files with 24 additions and 8 deletions

View File

@ -7,6 +7,7 @@ import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch_audiomentations import Compose
from enhancer.data.fileprocessor import Fileprocessor
from enhancer.utils import check_files
@ -14,9 +15,6 @@ from enhancer.utils.config import Files
from enhancer.utils.io import Audio
from enhancer.utils.random import create_unique_rng
# from torch_audiomentations import Compose
LARGE_NUM = 2147483647
@ -66,7 +64,7 @@ class TaskDataset(pl.LightningDataModule):
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
augmentations: Optional[Compose] = None,
):
super().__init__()
@ -86,7 +84,7 @@ class TaskDataset(pl.LightningDataModule):
else:
raise ValueError("valid_minutes must be greater than 0")
# self.augmentations = augmentations
self.augmentations = augmentations
def setup(self, stage: Optional[str] = None):
"""
@ -178,7 +176,25 @@ class TaskDataset(pl.LightningDataModule):
return metadata
def train_collatefn(self, batch):
raise NotImplementedError("Not implemented")
output = {"clean": [], "noisy": []}
for item in batch:
output["clean"].append(item["clean"])
output["noisy"].append(item["noisy"])
output["clean"] = torch.stack(output["clean"], dim=0)
output["noisy"] = torch.stack(output["noisy"], dim=0)
if self.augmentations is not None:
output["clean"] = self.augmentations(
output["clean"], sample_rate=self.sampling_rate
)
self.augmentations.freeze_parameters()
output["noisy"] = self.augmentations(
output["noisy"], sample_rate=self.sampling_rate
)
return output
@property
def generator(self):
@ -251,7 +267,7 @@ class EnhancerDataset(TaskDataset):
matching_function=None,
batch_size=32,
num_workers: Optional[int] = None,
# augmentations: Optional[Compose] = None,
augmentations: Optional[Compose] = None,
):
super().__init__(
@ -264,7 +280,7 @@ class EnhancerDataset(TaskDataset):
matching_function=matching_function,
batch_size=batch_size,
num_workers=num_workers,
# augmentations=augmentations,
augmentations=augmentations,
)
self.sampling_rate = sampling_rate