commit
ba271f8a2a
|
|
@ -1,4 +1,5 @@
|
|||
##local
|
||||
#local
|
||||
cli/train_config/dataset/Vctk_local.yaml
|
||||
.DS_Store
|
||||
outputs/
|
||||
datasets/
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
import os
|
||||
import hydra
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
|
||||
from enhancer.data.dataset import EnhancerDataset
|
||||
|
||||
@hydra.main(config_path="train_config",config_name="config")
|
||||
def main(config: DictConfig):
|
||||
|
||||
callbacks = []
|
||||
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
|
||||
run_name=config.mlflow.run_name)
|
||||
run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
|
||||
|
||||
|
||||
parameters = config.hyperparameters
|
||||
|
|
@ -21,12 +21,12 @@ def main(config: DictConfig):
|
|||
loss=parameters.get("loss"), metric = parameters.get("metric"))
|
||||
|
||||
checkpoint = ModelCheckpoint(
|
||||
dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
|
||||
dirpath="",filename="model",monitor="valid_loss",verbose=False,
|
||||
mode="min",every_n_epochs=1
|
||||
)
|
||||
callbacks.append(checkpoint)
|
||||
early_stopping = EarlyStopping(
|
||||
monitor=parameters.get("loss"),
|
||||
monitor="valid_loss",
|
||||
mode="min",
|
||||
min_delta=0.0,
|
||||
patience=100,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
defaults:
|
||||
- model : Demucs
|
||||
- model : WaveUnet
|
||||
- dataset : Vctk
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
- trainer : fastrun_dev
|
||||
- trainer : default
|
||||
- mlflow : experiment
|
||||
|
|
@ -2,8 +2,8 @@ _target_: enhancer.data.dataset.EnhancerDataset
|
|||
name : vctk
|
||||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 1.0
|
||||
sampling_rate: 48000
|
||||
batch_size: 32
|
||||
sampling_rate: 16000
|
||||
batch_size: 8
|
||||
|
||||
files:
|
||||
train_clean : clean_trainset_56spk_wav
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
loss : mse
|
||||
metric : mae
|
||||
lr : 0.001
|
||||
num_epochs : 10
|
||||
lr : 0.0001
|
||||
num_epochs : 100
|
||||
|
|
|
|||
|
|
@ -1,2 +1,2 @@
|
|||
experiment_name : "myexp"
|
||||
run_name : "myrun"
|
||||
experiment_name : shahules/enhancer
|
||||
run_name : baseline
|
||||
|
|
@ -1,16 +1,15 @@
|
|||
# @package _group_
|
||||
_target_: pytorch_lightning.Trainer
|
||||
accelerator: auto
|
||||
accumulate_grad_batches: 1
|
||||
amp_backend: native
|
||||
auto_lr_find: False
|
||||
auto_lr_find: True
|
||||
auto_scale_batch_size: False
|
||||
auto_select_gpus: True
|
||||
benchmark: False
|
||||
check_val_every_n_epoch: 1
|
||||
detect_anomaly: False
|
||||
deterministic: False
|
||||
devices: auto
|
||||
devices: -1
|
||||
enable_checkpointing: True
|
||||
enable_model_summary: True
|
||||
enable_progress_bar: True
|
||||
|
|
@ -23,8 +22,8 @@ limit_predict_batches: 1.0
|
|||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 50
|
||||
max_epochs: 1000
|
||||
log_every_n_steps: 10
|
||||
max_epochs: 100
|
||||
max_steps: null
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
|
|
@ -1,3 +1,2 @@
|
|||
# @package _group_
|
||||
_target_: pytorch_lightning.Trainer
|
||||
fast_dev_run: True
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
import glob
|
||||
import multiprocessing
|
||||
import math
|
||||
import os
|
||||
import pytorch_lightning as pl
|
||||
|
|
@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
|
|||
duration:float=1.0,
|
||||
sampling_rate:int=48000,
|
||||
matching_function = None,
|
||||
batch_size=32):
|
||||
batch_size=32,
|
||||
num_workers:Optional[int]=None):
|
||||
super().__init__()
|
||||
|
||||
self.name = name
|
||||
|
|
@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self.batch_size = batch_size
|
||||
self.matching_function = matching_function
|
||||
self._validation = []
|
||||
if num_workers is None:
|
||||
num_workers = multiprocessing.cpu_count()//2
|
||||
self.num_workers = num_workers
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
|
|
@ -64,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
|
|||
train_clean = os.path.join(self.root_dir,self.files.train_clean)
|
||||
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
|
||||
fp = Fileprocessor.from_name(self.name,train_clean,
|
||||
train_noisy,self.sampling_rate,
|
||||
self.matching_function)
|
||||
train_noisy, self.matching_function)
|
||||
self.train_data = fp.prepare_matching_dict()
|
||||
|
||||
val_clean = os.path.join(self.root_dir,self.files.test_clean)
|
||||
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
|
||||
fp = Fileprocessor.from_name(self.name,val_clean,
|
||||
val_noisy,self.sampling_rate,
|
||||
self.matching_function)
|
||||
val_noisy, self.matching_function)
|
||||
val_data = fp.prepare_matching_dict()
|
||||
|
||||
for item in val_data:
|
||||
|
|
@ -85,10 +85,10 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self._validation.append(({"clean":clean,"noisy":noisy},
|
||||
start_time))
|
||||
def train_dataloader(self):
|
||||
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
|
||||
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||
|
||||
def val_dataloader(self):
|
||||
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
|
||||
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
|
||||
|
||||
class EnhancerDataset(TaskDataset):
|
||||
"""Dataset object for creating clean-noisy speech enhancement datasets"""
|
||||
|
|
@ -101,7 +101,8 @@ class EnhancerDataset(TaskDataset):
|
|||
duration=1.0,
|
||||
sampling_rate=48000,
|
||||
matching_function=None,
|
||||
batch_size=32):
|
||||
batch_size=32,
|
||||
num_workers:Optional[int]=None):
|
||||
|
||||
super().__init__(
|
||||
name=name,
|
||||
|
|
@ -110,7 +111,8 @@ class EnhancerDataset(TaskDataset):
|
|||
sampling_rate=sampling_rate,
|
||||
duration=duration,
|
||||
matching_function = matching_function,
|
||||
batch_size=batch_size
|
||||
batch_size=batch_size,
|
||||
num_workers = num_workers,
|
||||
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
import glob
|
||||
import os
|
||||
from re import S
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
|
||||
class ProcessorFunctions:
|
||||
|
||||
@staticmethod
|
||||
def match_vtck(clean_path,noisy_path,sr):
|
||||
def match_vtck(clean_path,noisy_path):
|
||||
|
||||
matching_wavfiles = list()
|
||||
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
|
||||
|
|
@ -18,16 +19,15 @@ class ProcessorFunctions:
|
|||
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
|
||||
sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
|
||||
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
|
||||
(sr_clean==sr) and
|
||||
(sr_noisy==sr)):
|
||||
(sr_clean==sr_noisy)):
|
||||
matching_wavfiles.append(
|
||||
{"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
|
||||
"duration":clean_file.shape[-1]/sr}
|
||||
"duration":clean_file.shape[-1]/sr_clean}
|
||||
)
|
||||
return matching_wavfiles
|
||||
|
||||
@staticmethod
|
||||
def match_dns2020(clean_path,noisy_path,sr):
|
||||
def match_dns2020(clean_path,noisy_path):
|
||||
|
||||
matching_wavfiles = dict()
|
||||
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
|
||||
|
|
@ -38,11 +38,10 @@ class ProcessorFunctions:
|
|||
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
|
||||
sr_noisy, noisy_file = wavfile.read(noisy_file)
|
||||
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
|
||||
(sr_clean==sr) and
|
||||
(sr_noisy==sr)):
|
||||
(sr_clean==sr_noisy)):
|
||||
matching_wavfiles.update(
|
||||
{"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
|
||||
"duration":clean_file.shape[-1]/sr}
|
||||
"duration":clean_file.shape[-1]/sr_clean}
|
||||
)
|
||||
|
||||
return matching_wavfiles
|
||||
|
|
@ -54,12 +53,10 @@ class Fileprocessor:
|
|||
self,
|
||||
clean_dir,
|
||||
noisy_dir,
|
||||
sr = 16000,
|
||||
matching_function = None
|
||||
):
|
||||
self.clean_dir = clean_dir
|
||||
self.noisy_dir = noisy_dir
|
||||
self.sr = sr
|
||||
self.matching_function = matching_function
|
||||
|
||||
@classmethod
|
||||
|
|
@ -67,23 +64,22 @@ class Fileprocessor:
|
|||
name:str,
|
||||
clean_dir,
|
||||
noisy_dir,
|
||||
sr,
|
||||
matching_function=None
|
||||
):
|
||||
|
||||
if name.lower() == "vctk":
|
||||
return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck)
|
||||
return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
|
||||
elif name.lower() == "dns-2020":
|
||||
return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020)
|
||||
return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
|
||||
else:
|
||||
return cls(clean_dir,noisy_dir,sr, matching_function)
|
||||
return cls(clean_dir,noisy_dir, matching_function)
|
||||
|
||||
def prepare_matching_dict(self):
|
||||
|
||||
if self.matching_function is None:
|
||||
raise ValueError("Not a valid matching function")
|
||||
|
||||
return self.matching_function(self.clean_dir,self.noisy_dir,self.sr)
|
||||
return self.matching_function(self.clean_dir,self.noisy_dir)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from pathlib import Path
|
|||
from librosa import load as load_audio
|
||||
|
||||
from enhancer.utils import Audio
|
||||
from enhancer.utils.config import DEFAULT_DEVICE
|
||||
|
||||
class Inference:
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from base64 import encode
|
||||
from turtle import forward
|
||||
import logging
|
||||
from typing import Optional, Union, List
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -114,6 +113,10 @@ class Demucs(Model):
|
|||
|
||||
):
|
||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||
if dataset is not None:
|
||||
if sampling_rate!=dataset.sampling_rate:
|
||||
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
||||
sampling_rate = dataset.sampling_rate
|
||||
super().__init__(num_channels=num_channels,
|
||||
sampling_rate=sampling_rate,lr=lr,
|
||||
dataset=dataset,duration=duration,loss=loss, metric=metric)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from importlib import import_module
|
||||
from huggingface_hub import cached_download, hf_hub_url
|
||||
import logging
|
||||
import numpy as np
|
||||
import os
|
||||
from typing import Optional, Union, List, Text, Dict, Any
|
||||
|
|
@ -82,8 +83,10 @@ class Model(pl.LightningModule):
|
|||
loss = self.loss(prediction, target)
|
||||
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
|
||||
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="train_loss", value=loss.item(),
|
||||
step=self.global_step)
|
||||
self.log("train_loss",loss.item())
|
||||
return {"loss":loss}
|
||||
|
||||
def validation_step(self,batch,batch_idx:int):
|
||||
|
|
@ -92,11 +95,20 @@ class Model(pl.LightningModule):
|
|||
target = batch["clean"]
|
||||
prediction = self(mixed_waveform)
|
||||
|
||||
loss = self.metric(prediction, target)
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
|
||||
metric_val = self.metric(prediction, target)
|
||||
loss_val = self.loss(prediction, target)
|
||||
self.log("val_metric",metric_val.item())
|
||||
self.log("val_loss",loss_val.item())
|
||||
|
||||
return {"loss":loss}
|
||||
if self.logger:
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="val_loss",value=loss_val.item(),
|
||||
step=self.global_step)
|
||||
self.logger.experiment.log_metric(run_id=self.logger.run_id,
|
||||
key="val_metric",value=metric_val.item(),
|
||||
step=self.global_step)
|
||||
|
||||
return {"loss":loss_val}
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from tkinter import wantobjects
|
||||
import wave
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -70,6 +69,11 @@ class WaveUnet(Model):
|
|||
loss: Union[str, List] = "mse",
|
||||
metric:Union[str,List] = "mse"
|
||||
):
|
||||
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
|
||||
if dataset is not None:
|
||||
if sampling_rate!=dataset.sampling_rate:
|
||||
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
|
||||
sampling_rate = dataset.sampling_rate
|
||||
super().__init__(num_channels=num_channels,
|
||||
sampling_rate=sampling_rate,lr=lr,
|
||||
dataset=dataset,duration=duration,loss=loss, metric=metric
|
||||
|
|
@ -125,7 +129,6 @@ class WaveUnet(Model):
|
|||
|
||||
for layer,decoder in enumerate(self.decoders):
|
||||
out = F.interpolate(out, scale_factor=2, mode="linear")
|
||||
print(out.shape,encoder_outputs[layer].shape)
|
||||
out = self.fix_last_dim(out,encoder_outputs[layer])
|
||||
out = torch.cat([out,encoder_outputs[layer]],dim=1)
|
||||
out = decoder(out)
|
||||
|
|
|
|||
|
|
@ -1,18 +1,11 @@
|
|||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class Paths:
|
||||
log : str
|
||||
data : str
|
||||
|
||||
@dataclass
|
||||
class Files:
|
||||
root_dir : str
|
||||
train_clean : str
|
||||
train_noisy : str
|
||||
test_clean : str
|
||||
test_noisy : str
|
||||
|
||||
@dataclass
|
||||
class EnhancerConfig:
|
||||
path : Paths
|
||||
files: Files
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import os
|
||||
import librosa
|
||||
from typing import Optional
|
||||
from matplotlib.pyplot import axis
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ echo "Activate Environment"
|
|||
source activate enhancer
|
||||
export TRANSFORMERS_OFFLINE=True
|
||||
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
|
||||
export HYDRA_FULL_ERROR=1
|
||||
|
||||
echo $PYTHONPATH
|
||||
|
||||
source ~/mlflow_settings.sh
|
||||
|
|
|
|||
|
|
@ -10,5 +10,6 @@ tqdm==4.64.0
|
|||
mlflow==1.23.1
|
||||
protobuf==3.19.3
|
||||
boto3==1.23.9
|
||||
torchaudio==0.10.2
|
||||
huggingface-hub==0.4.0
|
||||
pytorch-lightning==1.5.10
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import pytest
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from enhancer.inference import Inference
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue