From e389acefa03c4c576413602ebc55c0c815f67d28 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 17:26:10 +0530
Subject: [PATCH 1/9] comment out early stopping

---
 enhancer/cli/train.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index 38300fd..de48e64 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -4,7 +4,7 @@ from types import MethodType
 import hydra
 from hydra.utils import instantiate
 from omegaconf import DictConfig, OmegaConf
-from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import MLFlowLogger
 from torch.optim.lr_scheduler import ReduceLROnPlateau

@@ -45,15 +45,15 @@ def main(config: DictConfig):
         every_n_epochs=1,
     )
     callbacks.append(checkpoint)
-    early_stopping = EarlyStopping(
-        monitor="val_loss",
-        mode=direction,
-        min_delta=0.0,
-        patience=parameters.get("EarlyStopping_patience", 10),
-        strict=True,
-        verbose=False,
-    )
-    callbacks.append(early_stopping)
+    # early_stopping = EarlyStopping(
+    #     monitor="val_loss",
+    #     mode=direction,
+    #     min_delta=0.0,
+    #     patience=parameters.get("EarlyStopping_patience", 10),
+    #     strict=True,
+    #     verbose=False,
+    # )
+    # callbacks.append(early_stopping)

     def configure_optimizer(self):
         optimizer = instantiate(

From 144b9d612803b6779bbf1a4b1fc22202f4f7b4ef Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 17:26:51 +0530
Subject: [PATCH 2/9] fix logging

---
 enhancer/models/model.py | 90 +++++++++++++---------------------------
 1 file changed, 28 insertions(+), 62 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 3f74a74..8eb19a8 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -132,45 +132,39 @@ class Model(pl.LightningModule):
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
         loss = self.loss(prediction, target)
-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss",
-                value=loss.item(),
-                step=self.global_step,
-            )
-        self.log("train_loss", loss.item())
+        self.log(
+            "train_loss",
+            loss.item(),
+            on_epoch=True,
+            on_step=True,
+            logger=True,
+            prog_bar=True,
+        )
+
+        return {"loss": loss}

     def validation_step(self, batch, batch_idx: int):

+        metric_dict = {}
         mixed_waveform = batch["noisy"]
         target = batch["clean"]
         prediction = self(mixed_waveform)
-        loss_val = self.loss(prediction, target)
-        self.log("val_loss", loss_val.item())
+        for metric in self.metric:
+            value = metric(target, prediction)
+            metric_dict[f"valid_{metric.name}"] = value.item()

-        if (
-            (self.logger)
-            and (self.global_step > 50)
-            and (self.global_step % 50 == 0)
-        ):
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="val_loss",
-                value=loss_val.item(),
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )

-        return {"loss": loss_val}
+        return metric_dict

     def test_step(self, batch, batch_idx):

             value = metric(target, prediction)
             metric_dict[metric.name] = value

-        for k, v in metric_dict.items():
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key=k,
-                value=v,
-                step=self.global_step,
-            )
+        self.log_dict(
+            metric_dict,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+            logger=True,
+        )

         return metric_dict

-    def training_epoch_end(self, outputs):
-        train_mean_loss = 0.0
-        for output in outputs:
-            train_mean_loss += output["loss"]
-        train_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="train_loss_epoch",
-                value=train_mean_loss,
-                step=self.current_epoch,
-            )
-
-    def validation_epoch_end(self, outputs):
-        valid_mean_loss = 0.0
-        for output in outputs:
-            valid_mean_loss += output["loss"]
-        valid_mean_loss /= len(outputs)
-
-        if self.logger:
-            self.logger.experiment.log_metric(
-                run_id=self.logger.run_id,
-                key="valid_loss_epoch",
-                value=valid_mean_loss,
-                step=self.current_epoch,
-            )
-
     def test_epoch_end(self, outputs):

         test_mean_metrics = defaultdict(int)
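Note on PATCH 2/9: it replaces the hand-rolled MLflow log_metric calls, and the training_epoch_end / validation_epoch_end reductions, with Lightning's built-in self.log / self.log_dict. The sketch below is not part of the patch series; it is a toy module, assuming pytorch-lightning 1.x (the API these diffs target), showing why the epoch-end hooks become redundant: with on_epoch=True Lightning accumulates the per-step values and logs the epoch mean itself.

import torch
import pytorch_lightning as pl


class TinyDenoiser(pl.LightningModule):
    # Hypothetical stand-in for enhancer's Model class, for illustration only.
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        prediction = self(batch["noisy"])
        loss = torch.nn.functional.mse_loss(prediction, batch["clean"])
        # Logged per step and aggregated per epoch; also forwarded to the
        # attached logger (MLFlowLogger in train.py) and the progress bar.
        self.log("train_loss", loss, on_step=True, on_epoch=True,
                 prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)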
prog_bar=True, + logger=True, + ) return metric_dict - def training_epoch_end(self, outputs): - train_mean_loss = 0.0 - for output in outputs: - train_mean_loss += output["loss"] - train_mean_loss /= len(outputs) - - if self.logger: - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="train_loss_epoch", - value=train_mean_loss, - step=self.current_epoch, - ) - - def validation_epoch_end(self, outputs): - valid_mean_loss = 0.0 - for output in outputs: - valid_mean_loss += output["loss"] - valid_mean_loss /= len(outputs) - - if self.logger: - self.logger.experiment.log_metric( - run_id=self.logger.run_id, - key="valid_loss_epoch", - value=valid_mean_loss, - step=self.current_epoch, - ) - def test_epoch_end(self, outputs): test_mean_metrics = defaultdict(int) From 1831dc201385a47d286acc804e9c172ed3514570 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 12 Oct 2022 17:28:04 +0530 Subject: [PATCH 3/9] config --- enhancer/cli/train_config/dataset/Vctk.yaml | 2 +- enhancer/cli/train_config/trainer/default.yaml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/enhancer/cli/train_config/dataset/Vctk.yaml b/enhancer/cli/train_config/dataset/Vctk.yaml index 5c19320..df50da2 100644 --- a/enhancer/cli/train_config/dataset/Vctk.yaml +++ b/enhancer/cli/train_config/dataset/Vctk.yaml @@ -3,7 +3,7 @@ name : vctk root_dir : /scratch/c.sistc3/DS_10283_2791 duration : 1.0 sampling_rate: 16000 -batch_size: 64 +batch_size: 128 files: train_clean : clean_trainset_28spk_wav diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml index 55101de..92ae56a 100644 --- a/enhancer/cli/train_config/trainer/default.yaml +++ b/enhancer/cli/train_config/trainer/default.yaml @@ -22,8 +22,8 @@ limit_predict_batches: 1.0 limit_test_batches: 1.0 limit_train_batches: 1.0 limit_val_batches: 1.0 -log_every_n_steps: 1 -max_epochs: 10 +log_every_n_steps: 50 +max_epochs: 3 max_steps: null max_time: null min_epochs: 1 From 6fb0aeae7fe76aa2b81a21738c348196d35012c3 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 12 Oct 2022 17:56:50 +0530 Subject: [PATCH 4/9] name loss wraper --- enhancer/loss.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enhancer/loss.py b/enhancer/loss.py index fd1a609..cdd15a5 100644 --- a/enhancer/loss.py +++ b/enhancer/loss.py @@ -151,9 +151,11 @@ class LossWrapper(nn.Module): ) self.higher_better = direction[0] + self.name = "" for loss in losses: loss = self.validate_loss(loss) self.valid_losses.append(loss()) + self.name += f"{loss().name}_" def validate_loss(self, loss: str): if loss not in LOSS_MAP.keys(): From 226758e91f46e0ada2b84fe3d4a30869ef096798 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Wed, 12 Oct 2022 17:57:11 +0530 Subject: [PATCH 5/9] hawk --- hpc_entrypoint.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh index 6d6a3a0..4c77127 100644 --- a/hpc_entrypoint.sh +++ b/hpc_entrypoint.sh @@ -36,4 +36,4 @@ pwd #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test echo "Start Training..." 
From 226758e91f46e0ada2b84fe3d4a30869ef096798 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 17:57:11 +0530
Subject: [PATCH 5/9] hawk

---
 hpc_entrypoint.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/hpc_entrypoint.sh b/hpc_entrypoint.sh
index 6d6a3a0..4c77127 100644
--- a/hpc_entrypoint.sh
+++ b/hpc_entrypoint.sh
@@ -36,4 +36,4 @@ pwd
 #python transcriber/tasks/embeddings/timit.py --directory /scratch/$USER/TIMIT/data/lisa/data/timit/raw/TIMIT/TEST --output ./data/test

 echo "Start Training..."
-python cli/train.py
+python enhancer/cli/train.py

From ad56a160e4074153602c81697a5061a4c3c09d40 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 17:57:42 +0530
Subject: [PATCH 6/9] train log rename

---
 enhancer/models/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 8eb19a8..3ad5fa7 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -135,7 +135,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)

         self.log(
-            "train_loss",
+            f"train_{self.loss.name}",
             loss.item(),
             on_epoch=True,
             on_step=True,
@@ -175,7 +175,7 @@ class Model(pl.LightningModule):
         for metric in self.metric:
             value = metric(target, prediction)
-            metric_dict[metric.name] = value
+            metric_dict[f"test_{metric.name}"] = value

         self.log_dict(
             metric_dict,

From 00ef644179ca9d34438de9f2922d08e2a7ba2d80 Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 17:58:22 +0530
Subject: [PATCH 7/9] max_steps -1

---
 enhancer/cli/train_config/trainer/default.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train_config/trainer/default.yaml b/enhancer/cli/train_config/trainer/default.yaml
index 92ae56a..dfc020f 100644
--- a/enhancer/cli/train_config/trainer/default.yaml
+++ b/enhancer/cli/train_config/trainer/default.yaml
@@ -24,7 +24,7 @@ limit_train_batches: 1.0
 limit_val_batches: 1.0
 log_every_n_steps: 50
 max_epochs: 3
-max_steps: null
+max_steps: -1
 max_time: null
 min_epochs: 1
 min_steps: null

From cb0040b508f29ccc55f68ff936dbe1cf863a3cba Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 18:44:22 +0530
Subject: [PATCH 8/9] rename loss

---
 enhancer/models/model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index 3ad5fa7..4f055b4 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -135,7 +135,7 @@ class Model(pl.LightningModule):
         loss = self.loss(prediction, target)

         self.log(
-            f"train_{self.loss.name}",
+            "train_loss",
             loss.item(),
             on_epoch=True,
             on_step=True,
@@ -152,6 +152,7 @@ class Model(pl.LightningModule):
         target = batch["clean"]
         prediction = self(mixed_waveform)

+        metric_dict["valid_loss"] = self.loss(target, prediction).item()
         for metric in self.metric:
             value = metric(target, prediction)
             metric_dict[f"valid_{metric.name}"] = value.item()

From 2e58091543b6bb76d94811b208e67369d6ba21aa Mon Sep 17 00:00:00 2001
From: shahules786
Date: Wed, 12 Oct 2022 18:45:07 +0530
Subject: [PATCH 9/9] rename loss

---
 enhancer/cli/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/enhancer/cli/train.py b/enhancer/cli/train.py
index de48e64..49f7b3b 100644
--- a/enhancer/cli/train.py
+++ b/enhancer/cli/train.py
@@ -39,7 +39,7 @@ def main(config: DictConfig):
     checkpoint = ModelCheckpoint(
         dirpath="./model",
         filename=f"model_{JOB_ID}",
-        monitor="val_loss",
+        monitor="valid_loss",
         verbose=False,
         mode=direction,
         every_n_epochs=1,
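Note on PATCHES 8/9 and 9/9: they are coupled. ModelCheckpoint only works if its monitor argument names a key that is actually logged, and after PATCH 8/9 the validation loss is logged as "valid_loss", so the callback must monitor that exact string; with the old "val_loss" Lightning would find no such metric. Two asides: PATCH 8/9 computes the validation loss as self.loss(target, prediction) while training_step uses self.loss(prediction, target), which is harmless for symmetric losses such as MAE or MSE but worth checking for asymmetric ones; and the sketch below is illustrative only, with mode="min" standing in for the direction variable and a literal filename where the real code interpolates JOB_ID.

from pytorch_lightning.callbacks import ModelCheckpoint

# Hedged sketch, not the repo's code: the checkpoint callback as configured
# after PATCH 9/9.
checkpoint = ModelCheckpoint(
    dirpath="./model",
    filename="model_demo",   # the real code uses f"model_{JOB_ID}"
    monitor="valid_loss",    # must exactly match a self.log / self.log_dict key
    verbose=False,
    mode="min",              # stand-in for `direction`
    every_n_epochs=1,
)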