From 07c525ca150fc3bf81b3d7272689cee355243404 Mon Sep 17 00:00:00 2001 From: shahules786 Date: Mon, 3 Oct 2022 20:00:14 +0530 Subject: [PATCH] fix key error --- enhancer/models/model.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/enhancer/models/model.py b/enhancer/models/model.py index b1bdd86..5827301 100644 --- a/enhancer/models/model.py +++ b/enhancer/models/model.py @@ -191,8 +191,8 @@ class Model(pl.LightningModule): map_location = torch.device(DEFAULT_DEVICE) loaded_checkpoint = pl_load(model_path_pl,map_location) - module_name = loaded_checkpoint["architecture"]["module"] - class_name = loaded_checkpoint["architecture"]["class"] + module_name = loaded_checkpoint["enhancer"]["architecture"]["module"] + class_name = loaded_checkpoint["enhancer"]["architecture"]["class"] module = import_module(module_name) Klass = getattr(module, class_name) @@ -216,11 +216,12 @@ class Model(pl.LightningModule): batch_predictions = [] self.eval().to(self.device) - for batch_id in range(batch.shape[0],batch_size): - batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) - prediction = self(batch_data) - batch_predictions.append(prediction) - + with torch.no_grad(): + for batch_id in range(0,batch.shape[0],batch_size): + batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) + prediction = self(batch_data) + batch_predictions.append(prediction) + return torch.vstack(batch_predictions) def enhance( @@ -232,9 +233,9 @@ class Model(pl.LightningModule): duration:Optional[int]=None, step_size:Optional[int]=None,): - model_sampling_rate = self.model.hprams("sampling_rate") + model_sampling_rate = self.hparams["sampling_rate"] if duration is None: - duration = self.model.hparams("duration") + duration = self.hparams["duration"] waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) waveform.to(self.device) window_size = round(duration * model_sampling_rate) @@ -246,8 +247,9 @@ class Model(pl.LightningModule): Inference.write_output(waveform,audio,model_sampling_rate) else: - return waveform - + waveform = Inference.prepare_output(waveform, model_sampling_rate, + audio, sampling_rate) + return waveform @property def valid_monitor(self):