fix key error

This commit is contained in:
shahules786 2022-10-03 20:00:14 +05:30
parent 1585911767
commit 07c525ca15
1 changed files with 13 additions and 11 deletions

View File

@ -191,8 +191,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE) map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl,map_location) loaded_checkpoint = pl_load(model_path_pl,map_location)
module_name = loaded_checkpoint["architecture"]["module"] module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["architecture"]["class"] class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name) module = import_module(module_name)
Klass = getattr(module, class_name) Klass = getattr(module, class_name)
@ -216,7 +216,8 @@ class Model(pl.LightningModule):
batch_predictions = [] batch_predictions = []
self.eval().to(self.device) self.eval().to(self.device)
for batch_id in range(batch.shape[0],batch_size): with torch.no_grad():
for batch_id in range(0,batch.shape[0],batch_size):
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
prediction = self(batch_data) prediction = self(batch_data)
batch_predictions.append(prediction) batch_predictions.append(prediction)
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
duration:Optional[int]=None, duration:Optional[int]=None,
step_size:Optional[int]=None,): step_size:Optional[int]=None,):
model_sampling_rate = self.model.hprams("sampling_rate") model_sampling_rate = self.hparams["sampling_rate"]
if duration is None: if duration is None:
duration = self.model.hparams("duration") duration = self.hparams["duration"]
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
waveform.to(self.device) waveform.to(self.device)
window_size = round(duration * model_sampling_rate) window_size = round(duration * model_sampling_rate)
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
Inference.write_output(waveform,audio,model_sampling_rate) Inference.write_output(waveform,audio,model_sampling_rate)
else: else:
waveform = Inference.prepare_output(waveform, model_sampling_rate,
audio, sampling_rate)
return waveform return waveform
@property @property
def valid_monitor(self): def valid_monitor(self):