Merge pull request #17 from shahules786/dev-datafix
foolproof iteration
Commit: a1445b0a95
Training entrypoint:

@@ -4,10 +4,16 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import (
+    EarlyStopping,
+    LearningRateMonitor,
+    ModelCheckpoint,
+)
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau
 
+# from torch_audiomentations import Compose, Shift
+
 os.environ["HYDRA_FULL_ERROR"] = "1"
 JOB_ID = os.environ.get("SLURM_JOBID", "0")
 
@@ -25,8 +31,13 @@ def main(config: DictConfig):
     )
 
     parameters = config.hyperparameters
+    # apply_augmentations = Compose(
+    #     [
+    #         Shift(min_shift=0.5, max_shift=1.0, shift_unit="seconds", p=0.5),
+    #     ]
+    # )
 
-    dataset = instantiate(config.dataset)
+    dataset = instantiate(config.dataset, augmentations=None)
     model = instantiate(
         config.model,
         dataset=dataset,
@@ -45,6 +56,8 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
+    callbacks.append(LearningRateMonitor(logging_interval="epoch"))
+
     if parameters.get("Early_stop", False):
         early_stopping = EarlyStopping(
             monitor="val_loss",
@@ -56,11 +69,11 @@ def main(config: DictConfig):
         )
         callbacks.append(early_stopping)
 
-    def configure_optimizer(self):
+    def configure_optimizers(self):
         optimizer = instantiate(
             config.optimizer,
             lr=parameters.get("lr"),
-            parameters=self.parameters(),
+            params=self.parameters(),
         )
         scheduler = ReduceLROnPlateau(
             optimizer=optimizer,
@@ -70,9 +83,13 @@ def main(config: DictConfig):
             min_lr=parameters.get("min_lr", 1e-6),
             patience=parameters.get("ReduceLr_patience", 3),
         )
-        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": f'valid_{parameters.get("ReduceLr_monitor", "loss")}',
+        }
 
-    model.configure_parameters = MethodType(configure_optimizer, model)
+    model.configure_optimizers = MethodType(configure_optimizers, model)
 
     trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
     trainer.fit(model)
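Note: the functional fix above is the rename. PyTorch Lightning only looks up configure_optimizers (the old configure_parameters was never called), and torch optimizers take their weights via the params keyword. A minimal standalone sketch of the monkey-patching pattern; DummyModel and the Adam/valid_loss choices are illustrative stand-ins, not the repo's exact setup:

import torch
import pytorch_lightning as pl
from types import MethodType
from torch.optim.lr_scheduler import ReduceLROnPlateau


def configure_optimizers(self):
    # `self` becomes the model instance once the function is bound.
    optimizer = torch.optim.Adam(params=self.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.2, patience=5)
    # With ReduceLROnPlateau, Lightning needs a "monitor" key naming the
    # logged metric that drives the learning-rate drops.
    return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "valid_loss"}


class DummyModel(pl.LightningModule):  # hypothetical minimal model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)


model = DummyModel()
# Bind the function as an instance method; trainer.fit(model) will find it.
model.configure_optimizers = MethodType(configure_optimizers, model)
print(model.configure_optimizers()["monitor"])  # valid_loss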
Dataset config (vctk):

@@ -2,10 +2,10 @@ _target_: enhancer.data.dataset.EnhancerDataset
 name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 4.5
-stride : 0.5
+stride : 2
 sampling_rate: 16000
-batch_size: 4
-valid_minutes : 1
+batch_size: 32
+valid_minutes : 15
 files:
   train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav
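With duration 4.5 and the new stride of 2, training windows are 4.5 s long and start every 2 s, instead of the near-total overlap at stride 0.5. A quick check of the segment-count formula the datamodule uses, on a hypothetical 10 s file:

import math


def get_num_segments(file_duration, duration, stride):
    # One window if the file is shorter than the window itself;
    # otherwise one per stride plus the final partial-overlap window.
    if file_duration < duration:
        return 1
    return math.ceil((file_duration - duration) / stride) + 1


print(get_num_segments(10.0, 4.5, 2))    # ceil(5.5 / 2) + 1 = 4
print(get_num_segments(10.0, 4.5, 0.5))  # old stride: 12 windows per file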
Hyperparameters config:

@@ -1,7 +1,7 @@
-loss : mse
-metric : mae
-lr : 0.0001
+loss : mae
+metric : [stoi,pesq,si-sdr]
+lr : 0.0003
 ReduceLr_patience : 5
-ReduceLr_factor : 0.1
+ReduceLr_factor : 0.2
 min_lr : 0.000001
 EarlyStopping_factor : 10
MLflow logger config:

@@ -1,2 +1,2 @@
 experiment_name : shahules/enhancer
-run_name : baseline
+run_name : Demucs + Vtck with stride + augmentations
Demucs model config:

@@ -1,11 +1,11 @@
 _target_: enhancer.models.demucs.Demucs
 num_channels: 1
-resample: 2
+resample: 4
 sampling_rate : 16000
 
 encoder_decoder:
-  depth: 5
-  initial_output_channels: 32
+  depth: 4
+  initial_output_channels: 64
   kernel_size: 8
   stride: 4
   growth_factor: 2
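For the new encoder shape: assuming each of the depth-4 layers downsamples by its stride of 4 and multiplies channels by the growth factor of 2, as in the original Demucs, the encoder compresses the internally 4x-upsampled waveform 256-fold while widening from 64 channels:

depth, stride, growth_factor, init_channels = 4, 4, 2, 64

# Total temporal downsampling inside the encoder, and channel widths per layer.
print(stride**depth)                                             # 256
print([init_channels * growth_factor**i for i in range(depth)])  # [64, 128, 256, 512]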
Wave-U-Net model config:

@@ -1,5 +1,5 @@
 _target_: enhancer.models.waveunet.WaveUnet
 num_channels : 1
-depth : 12
+depth : 9
 initial_output_channels: 24
 sampling_rate : 16000
Trainer config:

@@ -1,5 +1,5 @@
 _target_: pytorch_lightning.Trainer
-accelerator: auto
+accelerator: gpu
 accumulate_grad_batches: 1
 amp_backend: native
 auto_lr_find: True
@@ -9,7 +9,7 @@ benchmark: False
 check_val_every_n_epoch: 1
 detect_anomaly: False
 deterministic: False
-devices: 1
+devices: 2
 enable_checkpointing: True
 enable_model_summary: True
 enable_progress_bar: True
@@ -23,7 +23,7 @@ limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
-max_epochs: 3
+max_epochs: 200
 max_steps: -1
 max_time: null
 min_epochs: 1
@@ -38,7 +38,7 @@ precision: 32
 profiler: null
 reload_dataloaders_every_n_epochs: 0
 replace_sampler_ddp: True
-strategy: null
+strategy: ddp
 sync_batchnorm: False
 tpu_cores: null
 track_grad_norm: -1
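Since strategy ddp runs one dataloader per device, each optimizer step sees gradients from devices x batch_size samples. With the new values:

devices = 2                  # trainer config
batch_size = 32              # dataset config
accumulate_grad_batches = 1  # trainer config

print(devices * batch_size * accumulate_grad_batches)  # 64 samples per step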
enhancer/data/dataset.py:

@@ -1,14 +1,15 @@
 import math
 import multiprocessing
 import os
-import random
-from itertools import chain, cycle
+from pathlib import Path
 from typing import Optional
 
+import numpy as np
 import pytorch_lightning as pl
 import torch
 import torch.nn.functional as F
-from torch.utils.data import DataLoader, Dataset, IterableDataset
+from torch.utils.data import DataLoader, Dataset
+from torch_audiomentations import Compose
 
 from enhancer.data.fileprocessor import Fileprocessor
 from enhancer.utils import check_files
@@ -16,13 +17,15 @@ from enhancer.utils.config import Files
 from enhancer.utils.io import Audio
 from enhancer.utils.random import create_unique_rng
 
+LARGE_NUM = 2147483647
 
-class TrainDataset(IterableDataset):
+
+class TrainDataset(Dataset):
     def __init__(self, dataset):
         self.dataset = dataset
 
-    def __iter__(self):
-        return self.dataset.train__iter__()
+    def __getitem__(self, idx):
+        return self.dataset.train__getitem__(idx)
 
     def __len__(self):
         return self.dataset.train__len__()
@@ -63,6 +66,7 @@ class TaskDataset(pl.LightningDataModule):
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        augmentations: Optional[Compose] = None,
     ):
         super().__init__()
 
@@ -82,6 +86,8 @@ class TaskDataset(pl.LightningDataModule):
         else:
             raise ValueError("valid_minutes must be greater than 0")
 
+        self.augmentations = augmentations
+
    def setup(self, stage: Optional[str] = None):
         """
         prepare train/validation/test data splits
@@ -115,16 +121,29 @@ class TaskDataset(pl.LightningDataModule):
     ):
 
         valid_minutes *= 60
-        valid_min_now = 0.0
+        valid_sec_now = 0.0
         valid_indices = []
-        random_indices = list(range(0, len(data)))
-        rng = create_unique_rng(random_state)
-        rng.shuffle(random_indices)
-        i = 0
-        while valid_min_now <= valid_minutes:
-            valid_indices.append(random_indices[i])
-            valid_min_now += data[random_indices[i]]["duration"]
-            i += 1
+        all_speakers = np.unique(
+            [
+                (Path(file["clean"]).name.split("_")[0], file["duration"])
+                for file in data
+            ]
+        )
+        possible_indices = list(range(0, len(all_speakers)))
+        rng = create_unique_rng(len(all_speakers))
+
+        while valid_sec_now <= valid_minutes:
+            speaker_index = rng.choice(possible_indices)
+            possible_indices.remove(speaker_index)
+            speaker_name = all_speakers[speaker_index]
+            file_indices = [
+                i
+                for i, file in enumerate(data)
+                if speaker_name == Path(file["clean"]).name.split("_")[0]
+            ]
+            for i in file_indices:
+                valid_indices.append(i)
+                valid_sec_now += data[i]["duration"]
 
         train_data = [
             item for i, item in enumerate(data) if i not in valid_indices
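The validation split is now carved out whole speakers at a time instead of random files, so no speaker appears on both sides of the split. A minimal re-sketch of the idea, assuming VCTK-style filenames like p225_001.wav with the speaker id before the first underscore; the names here are illustrative, not the repo's API:

import random
from pathlib import Path


def speaker_split(files, valid_seconds, seed=0):
    """Pick whole speakers until `valid_seconds` of audio is set aside."""

    def speaker_of(f):
        return Path(f["clean"]).name.split("_")[0]

    speakers = sorted({speaker_of(f) for f in files})
    random.Random(seed).shuffle(speakers)

    valid_speakers, collected = set(), 0.0
    for spk in speakers:
        if collected > valid_seconds:
            break
        valid_speakers.add(spk)
        collected += sum(f["duration"] for f in files if speaker_of(f) == spk)

    valid = [f for f in files if speaker_of(f) in valid_speakers]
    train = [f for f in files if speaker_of(f) not in valid_speakers]
    return train, valid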
@@ -135,16 +154,11 @@ class TaskDataset(pl.LightningDataModule):
     def prepare_traindata(self, data):
         train_data = []
         for item in data:
-            samples_metadata = []
             clean, noisy, total_dur = item.values()
             num_segments = self.get_num_segments(
                 total_dur, self.duration, self.stride
             )
-            for index in range(num_segments):
-                start = index * self.stride
-                samples_metadata.append(
-                    ({"clean": clean, "noisy": noisy}, start)
-                )
+            samples_metadata = ({"clean": clean, "noisy": noisy}, num_segments)
             train_data.append(samples_metadata)
         return train_data
 
@@ -166,7 +180,9 @@ class TaskDataset(pl.LightningDataModule):
             if total_dur < self.duration:
                 metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
             else:
-                num_segments = round(total_dur / self.duration)
+                num_segments = self.get_num_segments(
+                    total_dur, self.duration, self.duration
+                )
                 for index in range(num_segments):
                     start_time = index * self.duration
                     metadata.append(
@@ -175,31 +191,44 @@ class TaskDataset(pl.LightningDataModule):
         return metadata
 
     def train_collatefn(self, batch):
-        output = {"noisy": [], "clean": []}
+
+        output = {"clean": [], "noisy": []}
         for item in batch:
-            output["noisy"].append(item["noisy"])
             output["clean"].append(item["clean"])
+            output["noisy"].append(item["noisy"])
 
         output["clean"] = torch.stack(output["clean"], dim=0)
         output["noisy"] = torch.stack(output["noisy"], dim=0)
+
+        if self.augmentations is not None:
+            noise = output["noisy"] - output["clean"]
+            output["clean"] = self.augmentations(
+                output["clean"], sample_rate=self.sampling_rate
+            )
+            self.augmentations.freeze_parameters()
+            output["noisy"] = (
+                self.augmentations(noise, sample_rate=self.sampling_rate)
+                + output["clean"]
+            )
+
         return output
 
-    def worker_init_fn(self, _):
-        worker_info = torch.utils.data.get_worker_info()
-        dataset = worker_info.dataset
-        worker_id = worker_info.id
-        split_size = len(dataset.dataset.train_data) // worker_info.num_workers
-        dataset.data = dataset.dataset.train_data[
-            worker_id * split_size : (worker_id + 1) * split_size
-        ]
+    @property
+    def generator(self):
+        generator = torch.Generator()
+        if hasattr(self, "model"):
+            seed = self.model.current_epoch + LARGE_NUM
+        else:
+            seed = LARGE_NUM
+        return generator.manual_seed(seed)
 
     def train_dataloader(self):
         return DataLoader(
             TrainDataset(self),
-            batch_size=None,
+            batch_size=self.batch_size,
             num_workers=self.num_workers,
+            generator=self.generator,
             collate_fn=self.train_collatefn,
-            worker_init_fn=self.worker_init_fn,
         )
 
     def val_dataloader(self):
@@ -251,11 +280,12 @@ class EnhancerDataset(TaskDataset):
         files: Files,
         valid_minutes=5.0,
         duration=1.0,
-        stride=0.5,
+        stride=None,
         sampling_rate=48000,
         matching_function=None,
         batch_size=32,
         num_workers: Optional[int] = None,
+        augmentations: Optional[Compose] = None,
     ):
 
         super().__init__(
@@ -268,6 +298,7 @@ class EnhancerDataset(TaskDataset):
             matching_function=matching_function,
             batch_size=batch_size,
             num_workers=num_workers,
+            augmentations=augmentations,
         )
 
         self.sampling_rate = sampling_rate
@@ -280,35 +311,17 @@ class EnhancerDataset(TaskDataset):
 
         super().setup(stage=stage)
 
-    def random_sample(self, train_data):
-        return random.sample(train_data, len(train_data))
-
-    def train__iter__(self):
-        rng = create_unique_rng(self.model.current_epoch)
-        train_data = rng.sample(self.train_data, len(self.train_data))
-        return zip(
-            *[
-                self.get_stream(self.random_sample(train_data))
-                for i in range(self.batch_size)
-            ]
-        )
-
-    def get_stream(self, data):
-        return chain.from_iterable(map(self.process_data, cycle(data)))
-
-    def process_data(self, data):
-        for item in data:
-            yield self.prepare_segment(*item)
-
-    @staticmethod
-    def get_num_segments(file_duration, duration, stride):
-
-        if file_duration < duration:
-            num_segments = 1
-        else:
-            num_segments = math.ceil((file_duration - duration) / stride) + 1
-
-        return num_segments
+    def train__getitem__(self, idx):
+
+        for filedict, num_samples in self.train_data:
+            if idx >= num_samples:
+                idx -= num_samples
+                continue
+            else:
+                start = 0
+                if self.duration is not None:
+                    start = idx * self.stride
+                return self.prepare_segment(filedict, start)
 
     def val__getitem__(self, idx):
         return self.prepare_segment(*self._validation[idx])
@@ -348,7 +361,8 @@ class EnhancerDataset(TaskDataset):
         }
 
     def train__len__(self):
-        return sum([len(item) for item in self.train_data]) // (self.batch_size)
+        _, num_examples = list(zip(*self.train_data))
+        return sum(num_examples)
 
     def val__len__(self):
         return len(self._validation)
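This is the "foolproof iteration" of the commit title: the old infinite IterableDataset stream, with its hand-rolled worker sharding, becomes a map-style dataset, so every segment is visited exactly once per epoch and PyTorch's sampler, batching, and worker splitting apply unchanged. train_data now stores one (file_dict, num_segments) pair per file, and train__getitem__ decodes a flat index into a (file, start time) pair. A standalone sketch of that index walk, with made-up metadata:

# (file metadata, number of segments in that file), as built by prepare_traindata
train_data = [
    ({"clean": "a_clean.wav", "noisy": "a_noisy.wav"}, 4),
    ({"clean": "b_clean.wav", "noisy": "b_noisy.wav"}, 2),
]
stride = 2.0


def locate(idx):
    """Map a flat dataset index to (file, start time in seconds)."""
    for filedict, num_segments in train_data:
        if idx >= num_segments:
            idx -= num_segments  # skip past this file's segments
            continue
        return filedict["clean"], idx * stride


total = sum(n for _, n in train_data)  # what train__len__ returns: 6
for i in range(total):
    print(i, locate(i))
# indices 0-3 fall in a_clean.wav (starts 0, 2, 4, 6 s), 4-5 in b_clean.wav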
Loss/metrics module:

@@ -3,7 +3,7 @@ import logging
 import numpy as np
 import torch
 import torch.nn as nn
-from pesq import pesq
+from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
 from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility
 
 
@@ -66,7 +66,7 @@ class Si_SDR:
                 "Invalid reduction, valid options are sum, mean, None"
             )
         self.higher_better = False
-        self.name = "Si-SDR"
+        self.name = "si-sdr"
 
     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
@@ -122,20 +122,16 @@ class Pesq:
         self.sr = sr
         self.name = "pesq"
         self.mode = mode
+        self.pesq = PerceptualEvaluationSpeechQuality(
+            fs=self.sr, mode=self.mode
+        )
 
     def __call__(self, prediction: torch.Tensor, target: torch.Tensor):
 
         pesq_values = []
         for pred, target_ in zip(prediction, target):
             try:
-                pesq_values.append(
-                    pesq(
-                        self.sr,
-                        target_.squeeze().detach().cpu().numpy(),
-                        pred.squeeze().detach().cpu().numpy(),
-                        self.mode,
-                    )
-                )
+                pesq_values.append(self.pesq(pred.squeeze(), target_.squeeze()))
             except Exception as e:
                 logging.warning(f"{e} error occured while calculating PESQ")
         return torch.tensor(np.mean(pesq_values))
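Constructing PerceptualEvaluationSpeechQuality once and feeding it tensors drops the per-call .detach().cpu().numpy() round trip the raw pesq function needed, and keeps the wrapper consistent with the STOI metric. A usage sketch; the 16 kHz rate and wide-band mode are assumptions taken from the dataset config:

import torch
from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality

pesq = PerceptualEvaluationSpeechQuality(fs=16000, mode="wb")

# Stand-ins for one enhanced waveform and its clean reference (1 s each).
prediction = torch.randn(16000)
target = torch.randn(16000)

score = pesq(prediction, target)  # scalar tensor, roughly in [-0.5, 4.5]
print(float(score))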
Model base class (LightningModule):

@@ -113,6 +113,8 @@ class Model(pl.LightningModule):
         if stage == "fit":
             torch.cuda.empty_cache()
             self.dataset.setup(stage)
+            self.dataset.model = self
+
             print(
                 "Total train duration",
                 self.dataset.train_dataloader().dataset.__len__()
@@ -134,7 +136,6 @@ class Model(pl.LightningModule):
                 / 60,
                 "minutes",
             )
-            self.dataset.model = self
 
     def train_dataloader(self):
         return self.dataset.train_dataloader()
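Moving self.dataset.model = self above the first train_dataloader() call matters because the datamodule's new generator property seeds itself from model.current_epoch when the back-reference exists, and only falls back to a fixed constant otherwise. A condensed sketch of that dependency; both class names are stand-ins:

import torch

LARGE_NUM = 2147483647


class DataModuleSketch:
    @property
    def generator(self):
        # Epoch-aware when the model is attached, fixed seed otherwise,
        # so shuffling differs per epoch but stays reproducible.
        seed = LARGE_NUM
        if hasattr(self, "model"):
            seed += self.model.current_epoch
        return torch.Generator().manual_seed(seed)


class ModelStub:
    current_epoch = 7


dm = DataModuleSketch()
print(dm.generator.initial_seed())  # 2147483647 (fallback)
dm.model = ModelStub()              # what Model.setup now does first
print(dm.generator.initial_seed())  # 2147483654 (epoch-aware)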
Audio I/O utility (enhancer/utils/io.py):

@@ -70,7 +70,7 @@ class Audio:
 
         if sampling_rate:
             audio = self.__class__.resample_audio(
-                audio, self.sampling_rate, sampling_rate
+                audio, sampling_rate, self.sampling_rate
            )
         if self.return_tensor:
             return torch.tensor(audio)
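This small swap is an argument-order fix: the file's native rate must be passed as the source and the object's configured rate as the target, so downsampling actually downsamples. Illustrative sketch, assuming resample_audio(audio, sr, target_sr) semantics equivalent to librosa.resample:

import numpy as np
import librosa

native_sr, target_sr = 48000, 16000
audio = np.zeros(48000)  # 1 s of 48 kHz audio

right = librosa.resample(audio, orig_sr=native_sr, target_sr=target_sr)
# Swapped order (the old bug) inflates the signal instead of shrinking it.
wrong = librosa.resample(audio, orig_sr=target_sr, target_sr=native_sr)

print(len(right), len(wrong))  # 16000 144000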
requirements.txt:

@@ -5,7 +5,7 @@ joblib>=1.2.0
 librosa>=0.9.2
 mlflow>=1.29.0
 numpy>=1.23.3
-git+https://github.com/ludlows/python-pesq#egg=pesq
+pesq==0.0.4
 protobuf>=3.19.6
 pystoi==0.3.3
 pytest-lazy-fixture>=0.6.3
@@ -14,5 +14,6 @@ scikit-learn>=1.1.2
 scipy>=1.9.1
 soundfile>=0.11.0
 torch>=1.12.1
+torch-audiomentations==0.11.0
 torchaudio>=0.12.1
 tqdm>=4.64.1