add augmentations
This commit is contained in:
parent
cdffe5c485
commit
d1bafb3dc6
|
|
@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf
|
||||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
|
from torch_audiomentations import BandPassFilter, Compose, Shift
|
||||||
|
|
||||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
@ -25,8 +26,14 @@ def main(config: DictConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters = config.hyperparameters
|
parameters = config.hyperparameters
|
||||||
|
apply_augmentations = Compose(
|
||||||
|
[
|
||||||
|
Shift(min_shift=0.0, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||||
|
BandPassFilter(p=0.5),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
dataset = instantiate(config.dataset)
|
dataset = instantiate(config.dataset, augmentations=apply_augmentations)
|
||||||
model = instantiate(
|
model = instantiate(
|
||||||
config.model,
|
config.model,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
|
|
||||||
|
|
@ -211,6 +211,7 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
generator=self.generator,
|
generator=self.generator,
|
||||||
|
collate_fn=self.train_collatefn,
|
||||||
)
|
)
|
||||||
|
|
||||||
def val_dataloader(self):
|
def val_dataloader(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue