fix key error
commit 07c525ca15
parent 1585911767
@@ -191,8 +191,8 @@ class Model(pl.LightningModule):
         map_location = torch.device(DEFAULT_DEVICE)
         loaded_checkpoint = pl_load(model_path_pl,map_location)
-        module_name = loaded_checkpoint["architecture"]["module"]
-        class_name = loaded_checkpoint["architecture"]["class"]
+        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)

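Note on the hunk above: the checkpoint nests its architecture metadata under an "enhancer" key, so indexing loaded_checkpoint["architecture"] directly is what raised the KeyError this commit fixes. Below is a minimal sketch of the dynamic class lookup the hunk performs; the checkpoint layout is taken from the hunk, while DEFAULT_DEVICE and the pl_load import path are assumptions, not verified against the rest of the repository.

from importlib import import_module

import torch
from pytorch_lightning.utilities.cloud_io import load as pl_load  # older PL location; adjust to your version

DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # assumption for the sketch


def resolve_model_class(model_path_pl: str):
    """Load a checkpoint and return the model class named in its metadata."""
    map_location = torch.device(DEFAULT_DEVICE)
    loaded_checkpoint = pl_load(model_path_pl, map_location)

    # Before this commit the lookup skipped the "enhancer" level and
    # therefore raised KeyError on "architecture".
    architecture = loaded_checkpoint["enhancer"]["architecture"]
    module = import_module(architecture["module"])  # module path string stored in the checkpoint
    return getattr(module, architecture["class"])   # the class object to instantiate

getattr on the imported module hands back the class object itself, so the caller can instantiate it later with whatever arguments the checkpoint also carries.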
@@ -216,7 +216,8 @@ class Model(pl.LightningModule):
         batch_predictions = []
         self.eval().to(self.device)

-        for batch_id in range(batch.shape[0],batch_size):
+        with torch.no_grad():
+            for batch_id in range(0,batch.shape[0],batch_size):
                 batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
                 prediction = self(batch_data)
                 batch_predictions.append(prediction)
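Note on the hunk above: the old call range(batch.shape[0], batch_size) treats the batch length as the start and the batch size as the stop, so it yields either nothing or the wrong indices; the fix strides over the first dimension in steps of batch_size and wraps the loop in torch.no_grad() so no autograd state is kept during inference. A standalone sketch of the same pattern, where the function name and the final torch.cat are illustrative and not part of the diff:

import torch


def batched_forward(model: torch.nn.Module, batch: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Run the model over `batch` in slices of `batch_size` without tracking gradients."""
    model.eval()
    predictions = []
    with torch.no_grad():
        # Stride over the first (batch) dimension: 0, batch_size, 2*batch_size, ...
        for start in range(0, batch.shape[0], batch_size):
            chunk = batch[start:start + batch_size]
            predictions.append(model(chunk))
    return torch.cat(predictions, dim=0)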
@@ -232,9 +233,9 @@ class Model(pl.LightningModule):
                 duration:Optional[int]=None,
                 step_size:Optional[int]=None,):

-        model_sampling_rate = self.model.hprams("sampling_rate")
+        model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
-            duration = self.model.hparams("duration")
+            duration = self.hparams["duration"]
         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
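Note on the hunk above: besides the "hprams" typo, hparams on a LightningModule is a dict-like object populated by save_hyperparameters(), not a callable, so the values are read with [] (or attribute access) on self.hparams directly rather than through self.model. A small illustration of that behaviour; the class name and hyperparameter values here are made up for the example:

import pytorch_lightning as pl


class ToyEnhancer(pl.LightningModule):
    def __init__(self, sampling_rate: int = 16000, duration: float = 2.0):
        super().__init__()
        # Stores the __init__ arguments on self.hparams (an AttributeDict).
        self.save_hyperparameters()


model = ToyEnhancer()
assert model.hparams["sampling_rate"] == 16000  # key access, as in the fixed code
assert model.hparams.duration == 2.0            # attribute access also works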
@@ -246,8 +247,9 @@ class Model(pl.LightningModule):
             Inference.write_output(waveform,audio,model_sampling_rate)

         else:
+            waveform = Inference.prepare_output(waveform, model_sampling_rate,
+                                                audio, sampling_rate)
             return waveform

     @property
     def valid_monitor(self):

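Note on the hunk above: before returning, the enhanced waveform is now passed through Inference.prepare_output, so the in-memory path hands back audio at the caller's sampling rate rather than the model's. The exact behaviour of prepare_output is not shown in this diff; as a rough stand-in, a conversion step of that kind could look like the sketch below, where the use of torchaudio is an assumption and not necessarily what the helper does:

import torch
import torchaudio


def to_caller_rate(waveform: torch.Tensor, model_sampling_rate: int, sampling_rate: int) -> torch.Tensor:
    """Resample model-rate output back to the sampling rate the caller supplied."""
    if sampling_rate == model_sampling_rate:
        return waveform
    return torchaudio.functional.resample(
        waveform, orig_freq=model_sampling_rate, new_freq=sampling_rate
    )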