fix key error

This commit is contained in:
shahules786 2022-10-03 20:00:14 +05:30
parent 1585911767
commit 07c525ca15
1 changed files with 13 additions and 11 deletions

View File

@ -191,8 +191,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl,map_location)
module_name = loaded_checkpoint["architecture"]["module"]
class_name = loaded_checkpoint["architecture"]["class"]
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name)
Klass = getattr(module, class_name)
@ -216,7 +216,8 @@ class Model(pl.LightningModule):
batch_predictions = []
self.eval().to(self.device)
for batch_id in range(batch.shape[0],batch_size):
with torch.no_grad():
for batch_id in range(0,batch.shape[0],batch_size):
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
prediction = self(batch_data)
batch_predictions.append(prediction)
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
duration:Optional[int]=None,
step_size:Optional[int]=None,):
model_sampling_rate = self.model.hprams("sampling_rate")
model_sampling_rate = self.hparams["sampling_rate"]
if duration is None:
duration = self.model.hparams("duration")
duration = self.hparams["duration"]
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
waveform.to(self.device)
window_size = round(duration * model_sampling_rate)
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
Inference.write_output(waveform,audio,model_sampling_rate)
else:
waveform = Inference.prepare_output(waveform, model_sampling_rate,
audio, sampling_rate)
return waveform
@property
def valid_monitor(self):