fix key error
This commit is contained in:
parent
1585911767
commit
07c525ca15
|
|
@ -191,8 +191,8 @@ class Model(pl.LightningModule):
|
|||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
||||
module_name = loaded_checkpoint["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["architecture"]["class"]
|
||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
|
|
@ -216,7 +216,8 @@ class Model(pl.LightningModule):
|
|||
batch_predictions = []
|
||||
self.eval().to(self.device)
|
||||
|
||||
for batch_id in range(batch.shape[0],batch_size):
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0,batch.shape[0],batch_size):
|
||||
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
|
||||
prediction = self(batch_data)
|
||||
batch_predictions.append(prediction)
|
||||
|
|
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
|
|||
duration:Optional[int]=None,
|
||||
step_size:Optional[int]=None,):
|
||||
|
||||
model_sampling_rate = self.model.hprams("sampling_rate")
|
||||
model_sampling_rate = self.hparams["sampling_rate"]
|
||||
if duration is None:
|
||||
duration = self.model.hparams("duration")
|
||||
duration = self.hparams["duration"]
|
||||
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
|
||||
waveform.to(self.device)
|
||||
window_size = round(duration * model_sampling_rate)
|
||||
|
|
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
|
|||
Inference.write_output(waveform,audio,model_sampling_rate)
|
||||
|
||||
else:
|
||||
waveform = Inference.prepare_output(waveform, model_sampling_rate,
|
||||
audio, sampling_rate)
|
||||
return waveform
|
||||
|
||||
@property
|
||||
def valid_monitor(self):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue