add torch-augmentations
This commit is contained in:
parent
5dc5fd8f90
commit
542ab23d8a
|
|
@ -7,6 +7,7 @@ import pytorch_lightning as pl
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from torch_audiomentations import Compose
|
||||
|
||||
from enhancer.data.fileprocessor import Fileprocessor
|
||||
from enhancer.utils import check_files
|
||||
|
|
@ -14,9 +15,6 @@ from enhancer.utils.config import Files
|
|||
from enhancer.utils.io import Audio
|
||||
from enhancer.utils.random import create_unique_rng
|
||||
|
||||
# from torch_audiomentations import Compose
|
||||
|
||||
|
||||
LARGE_NUM = 2147483647
|
||||
|
||||
|
||||
|
|
@ -66,7 +64,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
matching_function=None,
|
||||
batch_size=32,
|
||||
num_workers: Optional[int] = None,
|
||||
# augmentations: Optional[Compose] = None,
|
||||
augmentations: Optional[Compose] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -86,7 +84,7 @@ class TaskDataset(pl.LightningDataModule):
|
|||
else:
|
||||
raise ValueError("valid_minutes must be greater than 0")
|
||||
|
||||
# self.augmentations = augmentations
|
||||
self.augmentations = augmentations
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
"""
|
||||
|
|
@ -178,7 +176,25 @@ class TaskDataset(pl.LightningDataModule):
|
|||
return metadata
|
||||
|
||||
def train_collatefn(self, batch):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
output = {"clean": [], "noisy": []}
|
||||
for item in batch:
|
||||
output["clean"].append(item["clean"])
|
||||
output["noisy"].append(item["noisy"])
|
||||
|
||||
output["clean"] = torch.stack(output["clean"], dim=0)
|
||||
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
||||
|
||||
if self.augmentations is not None:
|
||||
output["clean"] = self.augmentations(
|
||||
output["clean"], sample_rate=self.sampling_rate
|
||||
)
|
||||
self.augmentations.freeze_parameters()
|
||||
output["noisy"] = self.augmentations(
|
||||
output["noisy"], sample_rate=self.sampling_rate
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
|
|
@ -251,7 +267,7 @@ class EnhancerDataset(TaskDataset):
|
|||
matching_function=None,
|
||||
batch_size=32,
|
||||
num_workers: Optional[int] = None,
|
||||
# augmentations: Optional[Compose] = None,
|
||||
augmentations: Optional[Compose] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
|
|
@ -264,7 +280,7 @@ class EnhancerDataset(TaskDataset):
|
|||
matching_function=matching_function,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
# augmentations=augmentations,
|
||||
augmentations=augmentations,
|
||||
)
|
||||
|
||||
self.sampling_rate = sampling_rate
|
||||
|
|
|
|||
Loading…
Reference in New Issue