add torch-augmentations
This commit is contained in:
parent
5dc5fd8f90
commit
542ab23d8a
|
|
@ -7,6 +7,7 @@ import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from torch_audiomentations import Compose
|
||||||
|
|
||||||
from enhancer.data.fileprocessor import Fileprocessor
|
from enhancer.data.fileprocessor import Fileprocessor
|
||||||
from enhancer.utils import check_files
|
from enhancer.utils import check_files
|
||||||
|
|
@ -14,9 +15,6 @@ from enhancer.utils.config import Files
|
||||||
from enhancer.utils.io import Audio
|
from enhancer.utils.io import Audio
|
||||||
from enhancer.utils.random import create_unique_rng
|
from enhancer.utils.random import create_unique_rng
|
||||||
|
|
||||||
# from torch_audiomentations import Compose
|
|
||||||
|
|
||||||
|
|
||||||
LARGE_NUM = 2147483647
|
LARGE_NUM = 2147483647
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -66,7 +64,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
# augmentations: Optional[Compose] = None,
|
augmentations: Optional[Compose] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -86,7 +84,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
else:
|
else:
|
||||||
raise ValueError("valid_minutes must be greater than 0")
|
raise ValueError("valid_minutes must be greater than 0")
|
||||||
|
|
||||||
# self.augmentations = augmentations
|
self.augmentations = augmentations
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
|
|
@ -178,7 +176,25 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def train_collatefn(self, batch):
|
def train_collatefn(self, batch):
|
||||||
raise NotImplementedError("Not implemented")
|
|
||||||
|
output = {"clean": [], "noisy": []}
|
||||||
|
for item in batch:
|
||||||
|
output["clean"].append(item["clean"])
|
||||||
|
output["noisy"].append(item["noisy"])
|
||||||
|
|
||||||
|
output["clean"] = torch.stack(output["clean"], dim=0)
|
||||||
|
output["noisy"] = torch.stack(output["noisy"], dim=0)
|
||||||
|
|
||||||
|
if self.augmentations is not None:
|
||||||
|
output["clean"] = self.augmentations(
|
||||||
|
output["clean"], sample_rate=self.sampling_rate
|
||||||
|
)
|
||||||
|
self.augmentations.freeze_parameters()
|
||||||
|
output["noisy"] = self.augmentations(
|
||||||
|
output["noisy"], sample_rate=self.sampling_rate
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def generator(self):
|
def generator(self):
|
||||||
|
|
@ -251,7 +267,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
# augmentations: Optional[Compose] = None,
|
augmentations: Optional[Compose] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
|
|
@ -264,7 +280,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
matching_function=matching_function,
|
matching_function=matching_function,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
# augmentations=augmentations,
|
augmentations=augmentations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue