Merge pull request #1 from shahules786/dev-hawk

debug code in hawk env
This commit is contained in:
Shahul ES 2022-09-30 10:28:49 +05:30 committed by GitHub
commit ba271f8a2a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 78 additions and 70 deletions

3
.gitignore vendored
View File

@ -1,4 +1,5 @@
##local #local
cli/train_config/dataset/Vctk_local.yaml
.DS_Store .DS_Store
outputs/ outputs/
datasets/ datasets/

View File

@ -1,17 +1,17 @@
import os
import hydra import hydra
from hydra.utils import instantiate from hydra.utils import instantiate
from omegaconf import DictConfig from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
from enhancer.data.dataset import EnhancerDataset
@hydra.main(config_path="train_config",config_name="config") @hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig): def main(config: DictConfig):
callbacks = [] callbacks = []
logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name, logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
run_name=config.mlflow.run_name) run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
parameters = config.hyperparameters parameters = config.hyperparameters
@ -21,12 +21,12 @@ def main(config: DictConfig):
loss=parameters.get("loss"), metric = parameters.get("metric")) loss=parameters.get("loss"), metric = parameters.get("metric"))
checkpoint = ModelCheckpoint( checkpoint = ModelCheckpoint(
dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False, dirpath="",filename="model",monitor="valid_loss",verbose=False,
mode="min",every_n_epochs=1 mode="min",every_n_epochs=1
) )
callbacks.append(checkpoint) callbacks.append(checkpoint)
early_stopping = EarlyStopping( early_stopping = EarlyStopping(
monitor=parameters.get("loss"), monitor="valid_loss",
mode="min", mode="min",
min_delta=0.0, min_delta=0.0,
patience=100, patience=100,

View File

@ -1,7 +1,7 @@
defaults: defaults:
- model : Demucs - model : WaveUnet
- dataset : Vctk - dataset : Vctk
- optimizer : Adam - optimizer : Adam
- hyperparameters : default - hyperparameters : default
- trainer : fastrun_dev - trainer : default
- mlflow : experiment - mlflow : experiment

View File

@ -2,8 +2,8 @@ _target_: enhancer.data.dataset.EnhancerDataset
name : vctk name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 1.0 duration : 1.0
sampling_rate: 48000 sampling_rate: 16000
batch_size: 32 batch_size: 8
files: files:
train_clean : clean_trainset_56spk_wav train_clean : clean_trainset_56spk_wav

View File

@ -1,4 +1,4 @@
loss : mse loss : mse
metric : mae metric : mae
lr : 0.001 lr : 0.0001
num_epochs : 10 num_epochs : 100

View File

@ -1,2 +1,2 @@
experiment_name : "myexp" experiment_name : shahules/enhancer
run_name : "myrun" run_name : baseline

View File

@ -1,16 +1,15 @@
# @package _group_
_target_: pytorch_lightning.Trainer _target_: pytorch_lightning.Trainer
accelerator: auto accelerator: auto
accumulate_grad_batches: 1 accumulate_grad_batches: 1
amp_backend: native amp_backend: native
auto_lr_find: False auto_lr_find: True
auto_scale_batch_size: False auto_scale_batch_size: False
auto_select_gpus: True auto_select_gpus: True
benchmark: False benchmark: False
check_val_every_n_epoch: 1 check_val_every_n_epoch: 1
detect_anomaly: False detect_anomaly: False
deterministic: False deterministic: False
devices: auto devices: -1
enable_checkpointing: True enable_checkpointing: True
enable_model_summary: True enable_model_summary: True
enable_progress_bar: True enable_progress_bar: True
@ -23,8 +22,8 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0 limit_test_batches: 1.0
limit_train_batches: 1.0 limit_train_batches: 1.0
limit_val_batches: 1.0 limit_val_batches: 1.0
log_every_n_steps: 50 log_every_n_steps: 10
max_epochs: 1000 max_epochs: 100
max_steps: null max_steps: null
max_time: null max_time: null
min_epochs: 1 min_epochs: 1

View File

@ -1,3 +1,2 @@
# @package _group_
_target_: pytorch_lightning.Trainer _target_: pytorch_lightning.Trainer
fast_dev_run: True fast_dev_run: True

View File

@ -1,6 +1,4 @@
import multiprocessing
from dataclasses import dataclass
import glob
import math import math
import os import os
import pytorch_lightning as pl import pytorch_lightning as pl
@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
duration:float=1.0, duration:float=1.0,
sampling_rate:int=48000, sampling_rate:int=48000,
matching_function = None, matching_function = None,
batch_size=32): batch_size=32,
num_workers:Optional[int]=None):
super().__init__() super().__init__()
self.name = name self.name = name
@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
self.batch_size = batch_size self.batch_size = batch_size
self.matching_function = matching_function self.matching_function = matching_function
self._validation = [] self._validation = []
if num_workers is None:
num_workers = multiprocessing.cpu_count()//2
self.num_workers = num_workers
def setup(self, stage: Optional[str] = None): def setup(self, stage: Optional[str] = None):
@ -64,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
train_clean = os.path.join(self.root_dir,self.files.train_clean) train_clean = os.path.join(self.root_dir,self.files.train_clean)
train_noisy = os.path.join(self.root_dir,self.files.train_noisy) train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
fp = Fileprocessor.from_name(self.name,train_clean, fp = Fileprocessor.from_name(self.name,train_clean,
train_noisy,self.sampling_rate, train_noisy, self.matching_function)
self.matching_function)
self.train_data = fp.prepare_matching_dict() self.train_data = fp.prepare_matching_dict()
val_clean = os.path.join(self.root_dir,self.files.test_clean) val_clean = os.path.join(self.root_dir,self.files.test_clean)
val_noisy = os.path.join(self.root_dir,self.files.test_noisy) val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
fp = Fileprocessor.from_name(self.name,val_clean, fp = Fileprocessor.from_name(self.name,val_clean,
val_noisy,self.sampling_rate, val_noisy, self.matching_function)
self.matching_function)
val_data = fp.prepare_matching_dict() val_data = fp.prepare_matching_dict()
for item in val_data: for item in val_data:
@ -85,10 +85,10 @@ class TaskDataset(pl.LightningDataModule):
self._validation.append(({"clean":clean,"noisy":noisy}, self._validation.append(({"clean":clean,"noisy":noisy},
start_time)) start_time))
def train_dataloader(self): def train_dataloader(self):
return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2) return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
def val_dataloader(self): def val_dataloader(self):
return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2) return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)
class EnhancerDataset(TaskDataset): class EnhancerDataset(TaskDataset):
"""Dataset object for creating clean-noisy speech enhancement datasets""" """Dataset object for creating clean-noisy speech enhancement datasets"""
@ -101,7 +101,8 @@ class EnhancerDataset(TaskDataset):
duration=1.0, duration=1.0,
sampling_rate=48000, sampling_rate=48000,
matching_function=None, matching_function=None,
batch_size=32): batch_size=32,
num_workers:Optional[int]=None):
super().__init__( super().__init__(
name=name, name=name,
@ -110,7 +111,8 @@ class EnhancerDataset(TaskDataset):
sampling_rate=sampling_rate, sampling_rate=sampling_rate,
duration=duration, duration=duration,
matching_function = matching_function, matching_function = matching_function,
batch_size=batch_size batch_size=batch_size,
num_workers = num_workers,
) )

View File

@ -1,12 +1,13 @@
import glob import glob
import os import os
from re import S
import numpy as np import numpy as np
from scipy.io import wavfile from scipy.io import wavfile
class ProcessorFunctions: class ProcessorFunctions:
@staticmethod @staticmethod
def match_vtck(clean_path,noisy_path,sr): def match_vtck(clean_path,noisy_path):
matching_wavfiles = list() matching_wavfiles = list()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@ -18,16 +19,15 @@ class ProcessorFunctions:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name)) sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name)) sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr) and (sr_clean==sr_noisy)):
(sr_noisy==sr)):
matching_wavfiles.append( matching_wavfiles.append(
{"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name), {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
"duration":clean_file.shape[-1]/sr} "duration":clean_file.shape[-1]/sr_clean}
) )
return matching_wavfiles return matching_wavfiles
@staticmethod @staticmethod
def match_dns2020(clean_path,noisy_path,sr): def match_dns2020(clean_path,noisy_path):
matching_wavfiles = dict() matching_wavfiles = dict()
clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))] clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@ -38,11 +38,10 @@ class ProcessorFunctions:
sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file)) sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
sr_noisy, noisy_file = wavfile.read(noisy_file) sr_noisy, noisy_file = wavfile.read(noisy_file)
if ((clean_file.shape[-1]==noisy_file.shape[-1]) and if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
(sr_clean==sr) and (sr_clean==sr_noisy)):
(sr_noisy==sr)):
matching_wavfiles.update( matching_wavfiles.update(
{"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file, {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
"duration":clean_file.shape[-1]/sr} "duration":clean_file.shape[-1]/sr_clean}
) )
return matching_wavfiles return matching_wavfiles
@ -54,12 +53,10 @@ class Fileprocessor:
self, self,
clean_dir, clean_dir,
noisy_dir, noisy_dir,
sr = 16000,
matching_function = None matching_function = None
): ):
self.clean_dir = clean_dir self.clean_dir = clean_dir
self.noisy_dir = noisy_dir self.noisy_dir = noisy_dir
self.sr = sr
self.matching_function = matching_function self.matching_function = matching_function
@classmethod @classmethod
@ -67,23 +64,22 @@ class Fileprocessor:
name:str, name:str,
clean_dir, clean_dir,
noisy_dir, noisy_dir,
sr,
matching_function=None matching_function=None
): ):
if name.lower() == "vctk": if name.lower() == "vctk":
return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck) return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
elif name.lower() == "dns-2020": elif name.lower() == "dns-2020":
return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020) return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
else: else:
return cls(clean_dir,noisy_dir,sr, matching_function) return cls(clean_dir,noisy_dir, matching_function)
def prepare_matching_dict(self): def prepare_matching_dict(self):
if self.matching_function is None: if self.matching_function is None:
raise ValueError("Not a valid matching function") raise ValueError("Not a valid matching function")
return self.matching_function(self.clean_dir,self.noisy_dir,self.sr) return self.matching_function(self.clean_dir,self.noisy_dir)

View File

@ -10,7 +10,6 @@ from pathlib import Path
from librosa import load as load_audio from librosa import load as load_audio
from enhancer.utils import Audio from enhancer.utils import Audio
from enhancer.utils.config import DEFAULT_DEVICE
class Inference: class Inference:

View File

@ -1,5 +1,4 @@
from base64 import encode import logging
from turtle import forward
from typing import Optional, Union, List from typing import Optional, Union, List
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
@ -114,6 +113,10 @@ class Demucs(Model):
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels, super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr, sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric) dataset=dataset,duration=duration,loss=loss, metric=metric)

View File

@ -1,5 +1,6 @@
from importlib import import_module from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url from huggingface_hub import cached_download, hf_hub_url
import logging
import numpy as np import numpy as np
import os import os
from typing import Optional, Union, List, Text, Dict, Any from typing import Optional, Union, List, Text, Dict, Any
@ -82,8 +83,10 @@ class Model(pl.LightningModule):
loss = self.loss(prediction, target) loss = self.loss(prediction, target)
if self.logger: if self.logger:
self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step) self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="train_loss", value=loss.item(),
step=self.global_step)
self.log("train_loss",loss.item())
return {"loss":loss} return {"loss":loss}
def validation_step(self,batch,batch_idx:int): def validation_step(self,batch,batch_idx:int):
@ -92,11 +95,20 @@ class Model(pl.LightningModule):
target = batch["clean"] target = batch["clean"]
prediction = self(mixed_waveform) prediction = self(mixed_waveform)
loss = self.metric(prediction, target) metric_val = self.metric(prediction, target)
if self.logger: loss_val = self.loss(prediction, target)
self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step) self.log("val_metric",metric_val.item())
self.log("val_loss",loss_val.item())
return {"loss":loss} if self.logger:
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_loss",value=loss_val.item(),
step=self.global_step)
self.logger.experiment.log_metric(run_id=self.logger.run_id,
key="val_metric",value=metric_val.item(),
step=self.global_step)
return {"loss":loss_val}
def on_save_checkpoint(self, checkpoint): def on_save_checkpoint(self, checkpoint):

View File

@ -1,5 +1,4 @@
from tkinter import wantobjects import logging
import wave
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@ -70,6 +69,11 @@ class WaveUnet(Model):
loss: Union[str, List] = "mse", loss: Union[str, List] = "mse",
metric:Union[str,List] = "mse" metric:Union[str,List] = "mse"
): ):
duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
if dataset is not None:
if sampling_rate!=dataset.sampling_rate:
logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
sampling_rate = dataset.sampling_rate
super().__init__(num_channels=num_channels, super().__init__(num_channels=num_channels,
sampling_rate=sampling_rate,lr=lr, sampling_rate=sampling_rate,lr=lr,
dataset=dataset,duration=duration,loss=loss, metric=metric dataset=dataset,duration=duration,loss=loss, metric=metric
@ -125,7 +129,6 @@ class WaveUnet(Model):
for layer,decoder in enumerate(self.decoders): for layer,decoder in enumerate(self.decoders):
out = F.interpolate(out, scale_factor=2, mode="linear") out = F.interpolate(out, scale_factor=2, mode="linear")
print(out.shape,encoder_outputs[layer].shape)
out = self.fix_last_dim(out,encoder_outputs[layer]) out = self.fix_last_dim(out,encoder_outputs[layer])
out = torch.cat([out,encoder_outputs[layer]],dim=1) out = torch.cat([out,encoder_outputs[layer]],dim=1)
out = decoder(out) out = decoder(out)

View File

@ -1,18 +1,11 @@
from dataclasses import dataclass from dataclasses import dataclass
@dataclass
class Paths:
log : str
data : str
@dataclass @dataclass
class Files: class Files:
root_dir : str
train_clean : str train_clean : str
train_noisy : str train_noisy : str
test_clean : str test_clean : str
test_noisy : str test_noisy : str
@dataclass
class EnhancerConfig:
path : Paths
files: Files

View File

@ -1,7 +1,6 @@
import os import os
import librosa import librosa
from typing import Optional from typing import Optional
from matplotlib.pyplot import axis
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio

View File

@ -22,6 +22,8 @@ echo "Activate Environment"
source activate enhancer source activate enhancer
export TRANSFORMERS_OFFLINE=True export TRANSFORMERS_OFFLINE=True
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
export HYDRA_FULL_ERROR=1
echo $PYTHONPATH echo $PYTHONPATH
source ~/mlflow_settings.sh source ~/mlflow_settings.sh

View File

@ -10,5 +10,6 @@ tqdm==4.64.0
mlflow==1.23.1 mlflow==1.23.1
protobuf==3.19.3 protobuf==3.19.3
boto3==1.23.9 boto3==1.23.9
torchaudio==0.10.2
huggingface-hub==0.4.0 huggingface-hub==0.4.0
pytorch-lightning==1.5.10 pytorch-lightning==1.5.10

View File

@ -1,6 +1,5 @@
import pytest import pytest
import torch import torch
import numpy as np
from enhancer.inference import Inference from enhancer.inference import Inference