commit ba271f8a2a
@@ -1,4 +1,5 @@
-##local
+#local
+cli/train_config/dataset/Vctk_local.yaml
 .DS_Store
 outputs/
 datasets/

@@ -1,17 +1,17 @@
+import os
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger

-from enhancer.data.dataset import EnhancerDataset

 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):

     callbacks = []
     logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                          run_name=config.mlflow.run_name)
+                          run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})

     parameters = config.hyperparameters

@@ -21,12 +21,12 @@ def main(config: DictConfig):
                 loss=parameters.get("loss"), metric = parameters.get("metric"))

     checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        dirpath="",filename="model",monitor="val_loss",verbose=False,
         mode="min",every_n_epochs=1
     )
     callbacks.append(checkpoint)
     early_stopping = EarlyStopping(
-        monitor=parameters.get("loss"),
+        monitor="val_loss",
         mode="min",
         min_delta=0.0,
         patience=100,

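Note: the key passed to ModelCheckpoint and EarlyStopping must be a metric name the LightningModule actually logs via self.log — here "val_loss", which the Model.validation_step change later in this diff adds — otherwise Lightning raises a misconfiguration error. A minimal sketch of the pattern, reusing the placeholder names from the old config:

    import os

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers import MLFlowLogger

    # tags are attached to the MLflow run; SLURM_JOBID is None outside a SLURM job
    logger = MLFlowLogger(experiment_name="myexp", run_name="myrun",
                          tags={"JOB_ID": os.environ.get("SLURM_JOBID")})

    # monitor must name a metric logged with self.log(...) in the model
    checkpoint = ModelCheckpoint(filename="model", monitor="val_loss", mode="min")
    early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=100)
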
@@ -1,7 +1,7 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
-  - trainer : fastrun_dev
+  - trainer : default
   - mlflow : experiment

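The defaults list selects one option per Hydra config group, and any group can still be swapped at runtime without editing this file. A sketch using the compose API (assuming Hydra >= 1.1; older releases expose the same functions under hydra.experimental):

    from hydra import compose, initialize

    # compose the train config but swap the trainer group back to the dev profile
    with initialize(config_path="train_config"):
        cfg = compose(config_name="config",
                      overrides=["trainer=fastrun_dev", "hyperparameters.lr=0.001"])
    print(cfg.trainer.fast_dev_run)  # True
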
@@ -2,8 +2,8 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
-sampling_rate: 48000
-batch_size: 32
+sampling_rate: 16000
+batch_size: 8

 files:
   train_clean : clean_trainset_56spk_wav

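Hydra builds this dataset from the _target_ key and passes the remaining entries to EnhancerDataset. An equivalent direct construction as a sketch — only train_clean appears in this diff, so the other three file entries are assumed to follow the standard VCTK 56-speaker folder names:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "_target_": "enhancer.data.dataset.EnhancerDataset",
        "name": "vctk",
        "root_dir": "/scratch/c.sistc3/DS_10283_2791",
        "duration": 1.0,
        "sampling_rate": 16000,
        "batch_size": 8,
        "files": {
            "train_clean": "clean_trainset_56spk_wav",
            "train_noisy": "noisy_trainset_56spk_wav",  # assumed
            "test_clean": "clean_testset_wav",          # assumed
            "test_noisy": "noisy_testset_wav",          # assumed
        },
    })
    dataset = instantiate(cfg)
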
@@ -1,4 +1,4 @@
 loss : mse
 metric : mae
-lr : 0.001
-num_epochs : 10
+lr : 0.0001
+num_epochs : 100

@@ -1,2 +1,2 @@
-experiment_name : "myexp"
-run_name : "myrun"
+experiment_name : shahules/enhancer
+run_name : baseline

@@ -1,16 +1,15 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 accelerator: auto
 accumulate_grad_batches: 1
 amp_backend: native
-auto_lr_find: False
+auto_lr_find: True
 auto_scale_batch_size: False
 auto_select_gpus: True
 benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: auto
+devices: -1
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -23,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 1000
+log_every_n_steps: 10
+max_epochs: 100
 max_steps: null
 max_time: null
 min_epochs: 1

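auto_lr_find: True only takes effect if Trainer.tune() runs before fit(), and devices: -1 requests every visible GPU. How these flags are consumed, sketched against the objects built in train.py:

    from hydra.utils import instantiate

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.tune(model)  # runs the learning-rate finder because auto_lr_find=True
    trainer.fit(model)
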
@@ -1,3 +1,2 @@
-# @package _group_
 _target_: pytorch_lightning.Trainer
 fast_dev_run: True

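fast_dev_run: True makes the Trainer execute a single train and validation batch with checkpointing and logging disabled, so this profile is purely for smoke-testing; the defaults list above now points at the full profile instead. Equivalent direct construction, for illustration:

    import pytorch_lightning as pl

    trainer = pl.Trainer(fast_dev_run=True)  # one batch per loop, no checkpoints
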
@@ -1,6 +1,4 @@
-from dataclasses import dataclass
-import glob
+import multiprocessing
 import math
 import os
 import pytorch_lightning as pl
@@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
                 duration:float=1.0,
                 sampling_rate:int=48000,
                 matching_function = None,
-                batch_size=32):
+                batch_size=32,
+                num_workers:Optional[int]=None):
        super().__init__()

        self.name = name
@@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
+        if num_workers is None:
+            num_workers = multiprocessing.cpu_count()//2
+        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):

@@ -64,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
            train_clean = os.path.join(self.root_dir,self.files.train_clean)
            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
            fp = Fileprocessor.from_name(self.name,train_clean,
-                                        train_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        train_noisy, self.matching_function)
            self.train_data = fp.prepare_matching_dict()

            val_clean = os.path.join(self.root_dir,self.files.test_clean)
            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
            fp = Fileprocessor.from_name(self.name,val_clean,
-                                        val_noisy,self.sampling_rate,
-                                        self.matching_function)
+                                        val_noisy, self.matching_function)
            val_data = fp.prepare_matching_dict()

            for item in val_data:
@@ -85,10 +85,10 @@ class TaskDataset(pl.LightningDataModule):
                self._validation.append(({"clean":clean,"noisy":noisy},
                                        start_time))
    def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)

    def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)

class EnhancerDataset(TaskDataset):
    """Dataset object for creating clean-noisy speech enhancement datasets"""
@@ -101,7 +101,8 @@ class EnhancerDataset(TaskDataset):
                duration=1.0,
                sampling_rate=48000,
                matching_function=None,
-                batch_size=32):
+                batch_size=32,
+                num_workers:Optional[int]=None):

        super().__init__(
            name=name,
@@ -110,7 +111,8 @@ class EnhancerDataset(TaskDataset):
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function = matching_function,
-            batch_size=batch_size
+            batch_size=batch_size,
+            num_workers = num_workers,

        )

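The worker count for both dataloaders is now configurable and defaults to half the visible cores instead of a hard-coded 2. A usage sketch with illustrative arguments (files stands for the configured file mapping):

    import multiprocessing

    dataset = EnhancerDataset(name="vctk", root_dir="/scratch/c.sistc3/DS_10283_2791",
                              files=files, duration=1.0, sampling_rate=16000,
                              batch_size=8)  # num_workers left as None
    assert dataset.num_workers == multiprocessing.cpu_count() // 2
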
@@ -1,12 +1,13 @@
 import glob
 import os
+from re import S
 import numpy as np
 from scipy.io import wavfile

 class ProcessorFunctions:

    @staticmethod
-    def match_vtck(clean_path,noisy_path,sr):
+    def match_vtck(clean_path,noisy_path):

        matching_wavfiles = list()
        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -18,16 +19,15 @@ class ProcessorFunctions:
            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                matching_wavfiles.append(
                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                )
        return matching_wavfiles

    @staticmethod
-    def match_dns2020(clean_path,noisy_path,sr):
+    def match_dns2020(clean_path,noisy_path):

        matching_wavfiles = dict()
        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -38,11 +38,10 @@ class ProcessorFunctions:
            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
            sr_noisy, noisy_file = wavfile.read(noisy_file)
            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                matching_wavfiles.update(
                    {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                )

        return matching_wavfiles
@@ -54,12 +53,10 @@ class Fileprocessor:
        self,
        clean_dir,
        noisy_dir,
-        sr = 16000,
        matching_function = None
    ):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
-        self.sr = sr
        self.matching_function = matching_function

    @classmethod
@@ -67,23 +64,22 @@ class Fileprocessor:
        name:str,
        clean_dir,
        noisy_dir,
-        sr,
        matching_function=None
    ):

        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
        else:
-            return cls(clean_dir,noisy_dir,sr, matching_function)
+            return cls(clean_dir,noisy_dir, matching_function)

    def prepare_matching_dict(self):

        if self.matching_function is None:
            raise ValueError("Not a valid matching function")

-        return self.matching_function(self.clean_dir,self.noisy_dir,self.sr)
+        return self.matching_function(self.clean_dir,self.noisy_dir)

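With the sampling-rate argument gone, the matching functions validate each pair internally (clean and noisy rates must agree) and duration is computed from the file's own rate. The new call shape, as a sketch with illustrative paths:

    fp = Fileprocessor.from_name("vctk",
                                 "/data/clean_trainset_56spk_wav",
                                 "/data/noisy_trainset_56spk_wav")
    pairs = fp.prepare_matching_dict()  # "vctk" selects ProcessorFunctions.match_vtck
    # each entry: {"clean": ..., "noisy": ..., "duration": n_samples / sr_clean}
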
@@ -10,7 +10,6 @@ from pathlib import Path
 from librosa import load as load_audio

 from enhancer.utils import Audio
-from enhancer.utils.config import DEFAULT_DEVICE

 class Inference:

@@ -1,9 +1,8 @@
-from base64 import encode
-from turtle import forward
+import logging
 from typing import Optional, Union, List
 from torch import nn
 import torch.nn.functional as F
 import math

 from enhancer.models.model import Model
 from enhancer.data.dataset import EnhancerDataset
@@ -114,6 +113,10 @@ class Demucs(Model):

    ):
        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warning(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
        super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric)

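The constructor now reconciles a model/dataset sampling-rate mismatch instead of training silently at the model default: it warns, then adopts the dataset's rate. Sketched behaviour with abbreviated arguments:

    dataset = EnhancerDataset(name="vctk", root_dir=root, files=files, sampling_rate=16000)
    model = Demucs(sampling_rate=48000, dataset=dataset)
    # warning: model sampling rate 48000 should match dataset sampling rate 16000
    # the model then proceeds at the dataset's 16000 Hz
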
@@ -1,5 +1,6 @@
 from importlib import import_module
 from huggingface_hub import cached_download, hf_hub_url
+import logging
 import numpy as np
 import os
 from typing import Optional, Union, List, Text, Dict, Any
@@ -82,8 +83,10 @@ class Model(pl.LightningModule):
        loss = self.loss(prediction, target)

        if self.logger:
-            self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="train_loss", value=loss.item(),
+                                              step=self.global_step)
+        self.log("train_loss",loss.item())
        return {"loss":loss}

    def validation_step(self,batch,batch_idx:int):
@@ -92,11 +95,20 @@ class Model(pl.LightningModule):
        target = batch["clean"]
        prediction = self(mixed_waveform)

-        loss = self.metric(prediction, target)
-        if self.logger:
-            self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
-
-        return {"loss":loss}
+        metric_val = self.metric(prediction, target)
+        loss_val = self.loss(prediction, target)
+        self.log("val_metric",metric_val.item())
+        self.log("val_loss",loss_val.item())

+        if self.logger:
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_loss",value=loss_val.item(),
+                                              step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_metric",value=metric_val.item(),
+                                              step=self.global_step)

+        return {"loss":loss_val}

    def on_save_checkpoint(self, checkpoint):

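Two distinct sinks are involved here. mlflow.tracking.MlflowClient.log_metric (what MLFlowLogger.experiment returns) logs one key per call and needs an explicit run_id, which is why the old log_metrics({...}) dict call is split up. Separately, the added self.log(...) calls are what expose "val_loss" to ModelCheckpoint and EarlyStopping; metrics sent only to MLflow are invisible to Lightning callbacks. The pattern in isolation, as a sketch:

    # inside a LightningModule validation_step
    self.log("val_loss", loss_val.item())  # visible to callbacks monitoring "val_loss"
    if self.logger:
        self.logger.experiment.log_metric(run_id=self.logger.run_id,
                                          key="val_loss", value=loss_val.item(),
                                          step=self.global_step)
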
@@ -1,5 +1,4 @@
-from tkinter import wantobjects
-import wave
+import logging
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -70,6 +69,11 @@ class WaveUnet(Model):
        loss: Union[str, List] = "mse",
        metric:Union[str,List] = "mse"
    ):
+        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warning(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
        super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric
@@ -125,7 +129,6 @@ class WaveUnet(Model):

        for layer,decoder in enumerate(self.decoders):
            out = F.interpolate(out, scale_factor=2, mode="linear")
-            print(out.shape,encoder_outputs[layer].shape)
            out = self.fix_last_dim(out,encoder_outputs[layer])
            out = torch.cat([out,encoder_outputs[layer]],dim=1)
            out = decoder(out)

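The deleted print was debug output around the skip connection. In that loop, decimated encoder activations are upsampled by 2 and must be trimmed to the matching encoder output's length before concatenation, since decimating an odd length and interpolating back overshoots by one sample. A standalone sketch (the slice is a hypothetical stand-in for fix_last_dim, whose body is not shown in this diff):

    import torch
    import torch.nn.functional as F

    out = torch.randn(1, 16, 51)    # decoder feature map (101-sample signal decimated by 2)
    skip = torch.randn(1, 16, 101)  # matching encoder output
    out = F.interpolate(out, scale_factor=2, mode="linear")  # length 102
    out = out[..., :skip.shape[-1]]                          # stand-in for fix_last_dim
    out = torch.cat([out, skip], dim=1)                      # (1, 32, 101) skip connection
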
@@ -1,18 +1,11 @@
 from dataclasses import dataclass

-@dataclass
-class Paths:
-    log : str
-    data : str

 @dataclass
 class Files:
+    root_dir : str
     train_clean : str
     train_noisy : str
     test_clean : str
     test_noisy : str

-@dataclass
-class EnhancerConfig:
-    path : Paths
-    files: Files

@@ -1,7 +1,6 @@
 import os
 import librosa
 from typing import Optional
-from matplotlib.pyplot import axis
 import numpy as np
 import torch
 import torchaudio

@@ -22,6 +22,8 @@ echo "Activate Environment"
 source activate enhancer
 export TRANSFORMERS_OFFLINE=True
 export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
+export HYDRA_FULL_ERROR=1

 echo $PYTHONPATH

 source ~/mlflow_settings.sh

@@ -10,5 +10,6 @@ tqdm==4.64.0
 mlflow==1.23.1
 protobuf==3.19.3
 boto3==1.23.9
+torchaudio==0.10.2
 huggingface-hub==0.4.0
 pytorch-lightning==1.5.10

@@ -1,6 +1,5 @@
 import pytest
 import torch
-import numpy as np

 from enhancer.inference import Inference