diff --git a/cli/train.py b/cli/train.py
index 8a4beaf..dee3d2e 100644
--- a/cli/train.py
+++ b/cli/train.py
@@ -1,3 +1,4 @@
+from genericpath import isfile
 import os
 from types import MethodType
 import hydra
@@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID")
+JOB_ID = os.environ.get("SLURM_JOBID","0")
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -55,8 +56,10 @@ def main(config: DictConfig):
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
 
-    if os.path.exists("./model/"):
-        logger.experiment.log_artifact(logger.run_id,f"./model/.*")
+
+    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id,saved_location)
 
 
 
diff --git a/cli/train_config/config.yaml b/cli/train_config/config.yaml
index 6b5d98e..61551bd 100644
--- a/cli/train_config/config.yaml
+++ b/cli/train_config/config.yaml
@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
diff --git a/cli/train_config/dataset/Vctk.yaml b/cli/train_config/dataset/Vctk.yaml
index d1c8646..129d9a8 100644
--- a/cli/train_config/dataset/Vctk.yaml
+++ b/cli/train_config/dataset/Vctk.yaml
@@ -3,12 +3,12 @@ name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
 sampling_rate: 16000
-batch_size: 8
+batch_size: 64
 files:
-  train_clean : clean_trainset_56spk_wav
+  train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav
-  train_noisy : noisy_trainset_56spk_wav
+  train_noisy : noisy_trainset_28spk_wav
   test_noisy : noisy_testset_wav
 
 
 
diff --git a/cli/train_config/hyperparameters/default.yaml b/cli/train_config/hyperparameters/default.yaml
index 49a7f80..82ac3c2 100644
--- a/cli/train_config/hyperparameters/default.yaml
+++ b/cli/train_config/hyperparameters/default.yaml
@@ -1,7 +1,6 @@
 loss : mse
 metric : mae
 lr : 0.0001
-num_epochs : 100
 ReduceLr_patience : 5
 ReduceLr_factor : 0.1
 min_lr : 0.000001
diff --git a/cli/train_config/trainer/default.yaml b/cli/train_config/trainer/default.yaml
index ab4e273..55101de 100644
--- a/cli/train_config/trainer/default.yaml
+++ b/cli/train_config/trainer/default.yaml
@@ -22,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 10
-max_epochs: 100
+log_every_n_steps: 1
+max_epochs: 10
 max_steps: null
 max_time: null
 min_epochs: 1
diff --git a/enhancer/inference.py b/enhancer/inference.py
index 27a9385..838ef5f 100644
--- a/enhancer/inference.py
+++ b/enhancer/inference.py
@@ -27,6 +27,8 @@ class Inference:
         if isinstance(audio, (np.ndarray, torch.Tensor)):
             assert sr is not None, "Invalid sampling rate!"
+            if len(audio.shape) == 1:
+                audio = audio.reshape(1,-1)
 
         if isinstance(audio, str):
             audio = Path(audio)
 
@@ -103,6 +105,8 @@ class Inference:
         window = get_window(window=window, Nx=data.shape[-1])
         window = torch.from_numpy(window).to(data.device)
         data *= window
+        step_size = window_size//2 if step_size is None else step_size
+
 
         data = data.permute(1, 2, 0)
         data = F.fold(
@@ -129,6 +133,9 @@ class Inference:
         if isinstance(filename, str):
             filename = Path(filename)
+
+        parent, name = filename.parent, "cleaned_"+filename.name
+        filename = parent/Path(name)
 
         if filename.is_file():
             raise FileExistsError(f"file {filename} already exists")
         else:
diff --git a/enhancer/models/__init__.py b/enhancer/models/__init__.py
index 0d21337..534a608 100644
--- a/enhancer/models/__init__.py
+++ b/enhancer/models/__init__.py
@@ -1,2 +1,3 @@
 from enhancer.models.demucs import Demucs
-from enhancer.models.waveunet import WaveUnet
\ No newline at end of file
+from enhancer.models.waveunet import WaveUnet
+from enhancer.models.model import Model
\ No newline at end of file
diff --git a/enhancer/models/model.py b/enhancer/models/model.py
index b1bdd86..de2edab 100644
--- a/enhancer/models/model.py
+++ b/enhancer/models/model.py
@@ -191,8 +191,8 @@ class Model(pl.LightningModule):
             map_location = torch.device(DEFAULT_DEVICE)
 
         loaded_checkpoint = pl_load(model_path_pl,map_location)
-        module_name = loaded_checkpoint["architecture"]["module"]
-        class_name = loaded_checkpoint["architecture"]["class"]
+        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)
 
@@ -216,11 +216,12 @@ class Model(pl.LightningModule):
         batch_predictions = []
         self.eval().to(self.device)
 
-        for batch_id in range(batch.shape[0],batch_size):
-            batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
-            prediction = self(batch_data)
-            batch_predictions.append(prediction)
-
+        with torch.no_grad():
+            for batch_id in range(0,batch.shape[0],batch_size):
+                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+                prediction = self(batch_data)
+                batch_predictions.append(prediction)
+
         return torch.vstack(batch_predictions)
 
     def enhance(
@@ -232,9 +233,9 @@ class Model(pl.LightningModule):
         duration:Optional[int]=None,
         step_size:Optional[int]=None,):
 
-        model_sampling_rate = self.model.hprams("sampling_rate")
+        model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
-            duration = self.model.hparams("duration")
+            duration = self.hparams["duration"]
         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
@@ -246,8 +247,9 @@ class Model(pl.LightningModule):
             Inference.write_output(waveform,audio,model_sampling_rate)
         else:
-            return waveform
-
+            waveform = Inference.prepare_output(waveform, model_sampling_rate,
+                                                audio, sampling_rate)
+            return waveform
 
 
     @property
     def valid_monitor(self):
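
Note on the infer_batch hunk above: the old loop, range(batch.shape[0],batch_size), starts at batch.shape[0] and so never executes for any input longer than batch_size, leaving batch_predictions empty. A minimal standalone sketch of the corrected chunked-inference pattern follows (function and variable names here are illustrative, not from the repo):

    import torch

    def infer_batch(model: torch.nn.Module, batch: torch.Tensor, batch_size: int = 32) -> torch.Tensor:
        """Run inference over `batch` in chunks of at most `batch_size` rows."""
        predictions = []
        model.eval()
        with torch.no_grad():  # skip autograd bookkeeping at inference time
            for start in range(0, batch.shape[0], batch_size):  # start at 0, stride by batch_size
                chunk = batch[start:start + batch_size]  # final chunk may be shorter; slicing is safe
                predictions.append(model(chunk))
        return torch.vstack(predictions)  # stack per-chunk outputs back into one tensor

As in the patch, torch.no_grad() keeps the forward passes from building a gradient graph, and starting range() at 0 guarantees every chunk is visited exactly once.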