Merge pull request #1 from shahules786/dev-hawk

debug code in hawk env
commit ba271f8a2a
Author: Shahul ES
Date:   2022-09-30 10:28:49 +05:30 (committed by GitHub)
19 changed files with 78 additions and 70 deletions

.gitignore
View File

@@ -1,4 +1,5 @@
-##local
+#local
+cli/train_config/dataset/Vctk_local.yaml
.DS_Store
outputs/
datasets/

View File

@@ -1,17 +1,17 @@
import os
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
from enhancer.data.dataset import EnhancerDataset

@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
    callbacks = []
    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
-                         run_name=config.mlflow.run_name)
+                         run_name=config.mlflow.run_name, tags={"JOB_ID":os.environ.get("SLURM_JOBID")})
    parameters = config.hyperparameters
@@ -21,12 +21,12 @@ def main(config: DictConfig):
        loss=parameters.get("loss"), metric = parameters.get("metric"))
    checkpoint = ModelCheckpoint(
-        dirpath="",filename="model",monitor=parameters.get("loss"),verbose=False,
+        dirpath="",filename="model",monitor="valid_loss",verbose=False,
        mode="min",every_n_epochs=1
    )
    callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
-        monitor=parameters.get("loss"),
+        monitor="valid_loss",
        mode="min",
        min_delta=0.0,
        patience=100,

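A minimal sketch of the logging/checkpoint wiring after this change, assuming the PyTorch Lightning 1.5 APIs pinned in requirements.txt below; the values mirror the configs in this PR, and the monitored key must match a metric the model logs via self.log:

    import os

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
    from pytorch_lightning.loggers import MLFlowLogger

    # Tag the MLflow run with the SLURM job id; the tag is None off-cluster.
    logger = MLFlowLogger(
        experiment_name="shahules/enhancer",
        run_name="baseline",
        tags={"JOB_ID": os.environ.get("SLURM_JOBID")},
    )

    # Both callbacks now monitor the fixed key "valid_loss" instead of the
    # loss name taken from the hyperparameter config.
    callbacks = [
        ModelCheckpoint(dirpath="", filename="model", monitor="valid_loss",
                        verbose=False, mode="min", every_n_epochs=1),
        EarlyStopping(monitor="valid_loss", mode="min", min_delta=0.0,
                      patience=100),
    ]
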
View File

@@ -1,7 +1,7 @@
defaults:
-  - model : Demucs
+  - model : WaveUnet
  - dataset : Vctk
  - optimizer : Adam
  - hyperparameters : default
-  - trainer : fastrun_dev
+  - trainer : default
  - mlflow : experiment

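For context, a hedged sketch of how Hydra consumes a defaults list like the one above: each entry selects a YAML file from the matching config group, and any node carrying a _target_ key can be built with hydra.utils.instantiate. The entry point below is illustrative, not the repo's exact cli/train.py:

    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig

    @hydra.main(config_path="train_config", config_name="config")
    def main(config: DictConfig):
        # "- model : WaveUnet" resolves to train_config/model/WaveUnet.yaml,
        # "- trainer : default" to train_config/trainer/default.yaml, etc.
        dataset = instantiate(config.dataset)  # _target_: ...EnhancerDataset
        trainer = instantiate(config.trainer)  # _target_: pytorch_lightning.Trainer

    if __name__ == "__main__":
        main()
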
View File

@@ -2,8 +2,8 @@ _target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 1.0
-sampling_rate: 48000
-batch_size: 32
+sampling_rate: 16000
+batch_size: 8
files:
  train_clean : clean_trainset_56spk_wav

View File

@@ -1,4 +1,4 @@
loss : mse
metric : mae
-lr : 0.001
-num_epochs : 10
+lr : 0.0001
+num_epochs : 100

View File

@@ -1,2 +1,2 @@
-experiment_name : "myexp"
-run_name : "myrun"
+experiment_name : shahules/enhancer
+run_name : baseline

View File

@@ -1,16 +1,15 @@
-# @package _group_
_target_: pytorch_lightning.Trainer
accelerator: auto
accumulate_grad_batches: 1
amp_backend: native
-auto_lr_find: False
+auto_lr_find: True
auto_scale_batch_size: False
auto_select_gpus: True
benchmark: False
check_val_every_n_epoch: 1
detect_anomaly: False
deterministic: False
-devices: auto
+devices: -1
enable_checkpointing: True
enable_model_summary: True
enable_progress_bar: True
@@ -23,8 +22,8 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
-log_every_n_steps: 50
-max_epochs: 1000
+log_every_n_steps: 10
+max_epochs: 100
max_steps: null
max_time: null
min_epochs: 1

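Two of these flags change runtime behaviour in ways the YAML alone does not show: devices: -1 asks Lightning for every available device, and auto_lr_find: True only takes effect if Trainer.tune() runs before fit(). A hedged sketch, where model is a placeholder LightningModule:

    from hydra.utils import instantiate

    def run(config, model):
        trainer = instantiate(config.trainer)  # _target_: pytorch_lightning.Trainer
        trainer.tune(model)  # runs the LR finder because auto_lr_find is True
        trainer.fit(model)
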
View File

@@ -1,3 +1,2 @@
-# @package _group_
_target_: pytorch_lightning.Trainer
fast_dev_run: True

View File

@@ -1,6 +1,4 @@
-from dataclasses import dataclass
import glob
import multiprocessing
-import math
import os
import pytorch_lightning as pl
@@ -46,7 +44,8 @@ class TaskDataset(pl.LightningDataModule):
        duration:float=1.0,
        sampling_rate:int=48000,
        matching_function = None,
-        batch_size=32):
+        batch_size=32,
+        num_workers:Optional[int]=None):
        super().__init__()
        self.name = name
@@ -56,6 +55,9 @@ class TaskDataset(pl.LightningDataModule):
        self.batch_size = batch_size
        self.matching_function = matching_function
        self._validation = []
+        if num_workers is None:
+            num_workers = multiprocessing.cpu_count()//2
+        self.num_workers = num_workers

    def setup(self, stage: Optional[str] = None):
@@ -64,15 +66,13 @@ class TaskDataset(pl.LightningDataModule):
        train_clean = os.path.join(self.root_dir,self.files.train_clean)
        train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
        fp = Fileprocessor.from_name(self.name,train_clean,
-                                     train_noisy,self.sampling_rate,
-                                     self.matching_function)
+                                     train_noisy, self.matching_function)
        self.train_data = fp.prepare_matching_dict()

        val_clean = os.path.join(self.root_dir,self.files.test_clean)
        val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
        fp = Fileprocessor.from_name(self.name,val_clean,
-                                     val_noisy,self.sampling_rate,
-                                     self.matching_function)
+                                     val_noisy, self.matching_function)
        val_data = fp.prepare_matching_dict()

        for item in val_data:
@@ -85,10 +85,10 @@ class TaskDataset(pl.LightningDataModule):
            self._validation.append(({"clean":clean,"noisy":noisy},
                                     start_time))

    def train_dataloader(self):
-        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(TrainDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)

    def val_dataloader(self):
-        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=2)
+        return DataLoader(ValidDataset(self), batch_size = self.batch_size,num_workers=self.num_workers)

class EnhancerDataset(TaskDataset):
    """Dataset object for creating clean-noisy speech enhancement datasets"""
@@ -101,7 +101,8 @@ class EnhancerDataset(TaskDataset):
        duration=1.0,
        sampling_rate=48000,
        matching_function=None,
-        batch_size=32):
+        batch_size=32,
+        num_workers:Optional[int]=None):

        super().__init__(
            name=name,
@@ -110,7 +111,8 @@ class EnhancerDataset(TaskDataset):
            sampling_rate=sampling_rate,
            duration=duration,
            matching_function = matching_function,
-            batch_size=batch_size
+            batch_size=batch_size,
+            num_workers = num_workers,
        )

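A self-contained sketch of the num_workers default introduced above, outside the LightningDataModule plumbing (make_loader is an illustrative helper, not repo code): when the caller passes nothing, each DataLoader gets half the machine's CPU count instead of the previous hard-coded 2.

    import multiprocessing
    from typing import Optional

    from torch.utils.data import DataLoader, Dataset

    def make_loader(dataset: Dataset, batch_size: int = 32,
                    num_workers: Optional[int] = None) -> DataLoader:
        # Mirror of the default above: half the cores, not a hard-coded 2.
        if num_workers is None:
            num_workers = multiprocessing.cpu_count() // 2
        return DataLoader(dataset, batch_size=batch_size,
                          num_workers=num_workers)
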
View File

@@ -1,12 +1,13 @@
import glob
import os
from re import S
import numpy as np
from scipy.io import wavfile

class ProcessorFunctions:

    @staticmethod
-    def match_vtck(clean_path,noisy_path,sr):
+    def match_vtck(clean_path,noisy_path):
        matching_wavfiles = list()
        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -18,16 +19,15 @@ class ProcessorFunctions:
            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,file_name))
            sr_noisy, noisy_file = wavfile.read(os.path.join(noisy_path,file_name))
            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                matching_wavfiles.append(
                    {"clean":os.path.join(clean_path,file_name),"noisy":os.path.join(noisy_path,file_name),
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                )
        return matching_wavfiles

    @staticmethod
-    def match_dns2020(clean_path,noisy_path,sr):
+    def match_dns2020(clean_path,noisy_path):
        matching_wavfiles = dict()
        clean_filenames = [file.split('/')[-1] for file in glob.glob(os.path.join(clean_path,"*.wav"))]
@@ -38,11 +38,10 @@ class ProcessorFunctions:
            sr_clean, clean_file = wavfile.read(os.path.join(clean_path,clean_file))
            sr_noisy, noisy_file = wavfile.read(noisy_file)
            if ((clean_file.shape[-1]==noisy_file.shape[-1]) and
-                (sr_clean==sr) and
-                (sr_noisy==sr)):
+                (sr_clean==sr_noisy)):
                matching_wavfiles.update(
                    {"clean":os.path.join(clean_path,clean_file),"noisy":noisy_file,
-                    "duration":clean_file.shape[-1]/sr}
+                    "duration":clean_file.shape[-1]/sr_clean}
                )
        return matching_wavfiles
@@ -54,12 +53,10 @@ class Fileprocessor:
        self,
        clean_dir,
        noisy_dir,
-        sr = 16000,
        matching_function = None
    ):
        self.clean_dir = clean_dir
        self.noisy_dir = noisy_dir
-        self.sr = sr
        self.matching_function = matching_function

    @classmethod
@@ -67,23 +64,22 @@ class Fileprocessor:
        name:str,
        clean_dir,
        noisy_dir,
-        sr,
        matching_function=None
    ):
        if name.lower() == "vctk":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_vtck)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_vtck)
        elif name.lower() == "dns-2020":
-            return cls(clean_dir,noisy_dir,sr, ProcessorFunctions.match_dns2020)
+            return cls(clean_dir,noisy_dir, ProcessorFunctions.match_dns2020)
        else:
-            return cls(clean_dir,noisy_dir,sr, matching_function)
+            return cls(clean_dir,noisy_dir, matching_function)

    def prepare_matching_dict(self):
        if self.matching_function is None:
            raise ValueError("Not a valid matching function")
-        return self.matching_function(self.clean_dir,self.noisy_dir,self.sr)
+        return self.matching_function(self.clean_dir,self.noisy_dir)

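The signature change drops the externally supplied sample rate: a clean/noisy pair now matches when the two files agree with each other, and duration is derived from the file's own rate. A standalone sketch of that contract (match_pair is an illustrative helper, not repo code):

    import os
    from typing import Optional

    from scipy.io import wavfile

    def match_pair(clean_dir: str, noisy_dir: str,
                   file_name: str) -> Optional[dict]:
        sr_clean, clean = wavfile.read(os.path.join(clean_dir, file_name))
        sr_noisy, noisy = wavfile.read(os.path.join(noisy_dir, file_name))
        # Equal lengths and mutually equal rates, rather than equality
        # against a fixed target rate.
        if clean.shape[-1] == noisy.shape[-1] and sr_clean == sr_noisy:
            return {"clean": os.path.join(clean_dir, file_name),
                    "noisy": os.path.join(noisy_dir, file_name),
                    "duration": clean.shape[-1] / sr_clean}
        return None
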
View File

@@ -10,7 +10,6 @@ from pathlib import Path
from librosa import load as load_audio
from enhancer.utils import Audio
-from enhancer.utils.config import DEFAULT_DEVICE

class Inference:

View File

@@ -1,9 +1,8 @@
-from base64 import encode
-from turtle import forward
+import logging
from typing import Optional, Union, List
from torch import nn
import torch.nn.functional as F
import math
import math
from enhancer.models.model import Model
from enhancer.data.dataset import EnhancerDataset
@@ -114,6 +113,10 @@ class Demucs(Model):
    ):
        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
        super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric)

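The new guard prefers the dataset's sampling rate over the constructor argument. Isolated, the logic looks like the sketch below (reconcile_sampling_rate is an illustrative helper; note also that the committed code calls the deprecated logging.warn alias, for which logging.warning is the current spelling):

    import logging

    def reconcile_sampling_rate(sampling_rate: int, dataset) -> int:
        # Warn on mismatch, then defer to the dataset's rate.
        if dataset is not None:
            if sampling_rate != dataset.sampling_rate:
                logging.warning(
                    "model sampling rate %s should match dataset "
                    "sampling rate %s", sampling_rate, dataset.sampling_rate)
            sampling_rate = dataset.sampling_rate
        return sampling_rate
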
View File

@@ -1,5 +1,6 @@
from importlib import import_module
from huggingface_hub import cached_download, hf_hub_url
+import logging
import numpy as np
import os
from typing import Optional, Union, List, Text, Dict, Any
@@ -82,8 +83,10 @@ class Model(pl.LightningModule):
        loss = self.loss(prediction, target)
        if self.logger:
-            self.logger.experiment.log_metrics({"train_loss":loss.item()}, step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="train_loss", value=loss.item(),
+                                              step=self.global_step)
        self.log("train_loss",loss.item())
        return {"loss":loss}

    def validation_step(self,batch,batch_idx:int):
@@ -92,11 +95,20 @@ class Model(pl.LightningModule):
        target = batch["clean"]
        prediction = self(mixed_waveform)
-        loss = self.metric(prediction, target)
-        if self.logger:
-            self.logger.experiment.log_metrics({"val_loss":loss.item()}, step=self.global_step)
+        metric_val = self.metric(prediction, target)
+        loss_val = self.loss(prediction, target)
+        self.log("val_metric",metric_val.item())
+        self.log("val_loss",loss_val.item())

-        return {"loss":loss}
+        if self.logger:
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_loss",value=loss_val.item(),
+                                              step=self.global_step)
+            self.logger.experiment.log_metric(run_id=self.logger.run_id,
+                                              key="val_metric",value=metric_val.item(),
+                                              step=self.global_step)
+        return {"loss":loss_val}

    def on_save_checkpoint(self, checkpoint):

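The switch from log_metrics to log_metric matters because MLFlowLogger.experiment returns an mlflow.tracking.MlflowClient, whose per-metric method requires an explicit run_id (unlike the fluent module-level mlflow.log_metrics). A hedged sketch of the resulting validation step, with F.mse_loss standing in for the repo's configurable loss:

    import pytorch_lightning as pl
    import torch.nn.functional as F

    class LitEnhancer(pl.LightningModule):  # illustrative stand-in for Model
        def validation_step(self, batch, batch_idx):
            prediction = self(batch["noisy"])
            loss_val = F.mse_loss(prediction, batch["clean"])
            self.log("val_loss", loss_val.item())  # feeds Lightning callbacks
            if self.logger:
                # MlflowClient.log_metric(run_id, key, value, step=...)
                self.logger.experiment.log_metric(
                    run_id=self.logger.run_id, key="val_loss",
                    value=loss_val.item(), step=self.global_step)
            return {"loss": loss_val}
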
View File

@@ -1,5 +1,4 @@
-from tkinter import wantobjects
-import wave
+import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -70,6 +69,11 @@ class WaveUnet(Model):
        loss: Union[str, List] = "mse",
        metric:Union[str,List] = "mse"
    ):
+        duration = dataset.duration if isinstance(dataset,EnhancerDataset) else None
+        if dataset is not None:
+            if sampling_rate!=dataset.sampling_rate:
+                logging.warn(f"model sampling rate {sampling_rate} should match dataset sampling rate {dataset.sampling_rate}")
+            sampling_rate = dataset.sampling_rate
        super().__init__(num_channels=num_channels,
                         sampling_rate=sampling_rate,lr=lr,
                         dataset=dataset,duration=duration,loss=loss, metric=metric
@@ -125,7 +129,6 @@ class WaveUnet(Model):
        for layer,decoder in enumerate(self.decoders):
            out = F.interpolate(out, scale_factor=2, mode="linear")
-            print(out.shape,encoder_outputs[layer].shape)
            out = self.fix_last_dim(out,encoder_outputs[layer])
            out = torch.cat([out,encoder_outputs[layer]],dim=1)
            out = decoder(out)

View File

@@ -1,18 +1,11 @@
from dataclasses import dataclass

@dataclass
class Paths:
    log : str
    data : str

@dataclass
class Files:
    root_dir : str
    train_clean : str
    train_noisy : str
    test_clean : str
    test_noisy : str

@dataclass
class EnhancerConfig:
    path : Paths
    files: Files

View File

@@ -1,7 +1,6 @@
import os
import librosa
from typing import Optional
-from matplotlib.pyplot import axis
import numpy as np
import torch
import torchaudio

View File

@@ -22,6 +22,8 @@ echo "Activate Environment"
source activate enhancer
export TRANSFORMERS_OFFLINE=True
export PYTHONPATH=${PYTHONPATH}:/scratch/c.sistc3/enhancer
+export HYDRA_FULL_ERROR=1
+echo $PYTHONPATH
source ~/mlflow_settings.sh

View File

@@ -10,5 +10,6 @@ tqdm==4.64.0
mlflow==1.23.1
protobuf==3.19.3
boto3==1.23.9
+torchaudio==0.10.2
huggingface-hub==0.4.0
pytorch-lightning==1.5.10

View File

@@ -1,6 +1,5 @@
import pytest
import torch
import numpy as np
from enhancer.inference import Inference