rmv augmentations
This commit is contained in:
parent
c51dea6885
commit
47bbee2c32
|
|
@ -11,7 +11,8 @@ from pytorch_lightning.callbacks import (
|
||||||
)
|
)
|
||||||
from pytorch_lightning.loggers import MLFlowLogger
|
from pytorch_lightning.loggers import MLFlowLogger
|
||||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||||
from torch_audiomentations import Compose, Shift
|
|
||||||
|
# from torch_audiomentations import Compose, Shift
|
||||||
|
|
||||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||||
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
JOB_ID = os.environ.get("SLURM_JOBID", "0")
|
||||||
|
|
@ -30,13 +31,13 @@ def main(config: DictConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
parameters = config.hyperparameters
|
parameters = config.hyperparameters
|
||||||
apply_augmentations = Compose(
|
# apply_augmentations = Compose(
|
||||||
[
|
# [
|
||||||
Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
# Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
|
||||||
]
|
# ]
|
||||||
)
|
# )
|
||||||
|
|
||||||
dataset = instantiate(config.dataset, augmentations=apply_augmentations)
|
dataset = instantiate(config.dataset, augmentations=None)
|
||||||
model = instantiate(
|
model = instantiate(
|
||||||
config.model,
|
config.model,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue