Merge pull request #3 from shahules786/dev-hawk

fix inference issues
This commit is contained in:
Shahul ES 2022-10-03 20:03:21 +05:30 committed by GitHub
commit a4c1769efe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 43 additions and 23 deletions

View File

@ -1,3 +1,4 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID")
JOB_ID = os.environ.get("SLURM_JOBID","0")
@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):
@ -55,8 +56,10 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model)
if os.path.exists("./model/"):
logger.experiment.log_artifact(logger.run_id,f"./model/.*")
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id,saved_location)

View File

@ -1,5 +1,5 @@
defaults:
- model : Demucs
- model : WaveUnet
- dataset : Vctk
- optimizer : Adam
- hyperparameters : default

View File

@ -3,12 +3,12 @@ name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 1.0
sampling_rate: 16000
batch_size: 8
batch_size: 64
files:
train_clean : clean_trainset_56spk_wav
train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav
train_noisy : noisy_trainset_56spk_wav
train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav

View File

@ -1,7 +1,6 @@
loss : mse
metric : mae
lr : 0.0001
num_epochs : 100
ReduceLr_patience : 5
ReduceLr_factor : 0.1
min_lr : 0.000001

View File

@ -22,8 +22,8 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0
limit_train_batches: 1.0
limit_val_batches: 1.0
log_every_n_steps: 10
max_epochs: 100
log_every_n_steps: 1
max_epochs: 10
max_steps: null
max_time: null
min_epochs: 1

View File

@ -18,6 +18,8 @@ class Inference:
if isinstance(audio,(np.ndarray,torch.Tensor)):
assert sr is not None, "Invalid sampling rate!"
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
if isinstance(audio,str):
audio = Path(audio)
@ -65,6 +67,8 @@ class Inference:
window = get_window(window=window,Nx=data.shape[-1])
window = torch.from_numpy(window).to(data.device)
data *= window
step_size = window_size//2 if step_size is None else step_size
data = data.permute(1,2,0)
data = F.fold(data,
@ -85,6 +89,17 @@ class Inference:
else:
wavfile.write(filename,rate=sr,data=waveform.detach().cpu())
@staticmethod
def prepare_output(waveform:torch.Tensor, model_sampling_rate:int,
audio:Union[str,np.ndarray,torch.Tensor], sampling_rate:Optional[int]
):
if isinstance(audio,np.ndarray):
waveform = waveform.detach().cpu().numpy()
if sampling_rate!=None:
waveform = Audio.resample_audio(waveform, sr=model_sampling_rate, target_sr=sampling_rate)
return waveform

View File

@ -1,2 +1,3 @@
from enhancer.models.demucs import Demucs
from enhancer.models.waveunet import WaveUnet
from enhancer.models.model import Model

View File

@ -191,8 +191,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl,map_location)
module_name = loaded_checkpoint["architecture"]["module"]
class_name = loaded_checkpoint["architecture"]["class"]
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name)
Klass = getattr(module, class_name)
@ -216,7 +216,8 @@ class Model(pl.LightningModule):
batch_predictions = []
self.eval().to(self.device)
for batch_id in range(batch.shape[0],batch_size):
with torch.no_grad():
for batch_id in range(0,batch.shape[0],batch_size):
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
prediction = self(batch_data)
batch_predictions.append(prediction)
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
duration:Optional[int]=None,
step_size:Optional[int]=None,):
model_sampling_rate = self.model.hprams("sampling_rate")
model_sampling_rate = self.hparams["sampling_rate"]
if duration is None:
duration = self.model.hparams("duration")
duration = self.hparams["duration"]
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
waveform.to(self.device)
window_size = round(duration * model_sampling_rate)
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
Inference.write_output(waveform,audio,model_sampling_rate)
else:
waveform = Inference.prepare_output(waveform, model_sampling_rate,
audio, sampling_rate)
return waveform
@property
def valid_monitor(self):