merge fix

This commit is contained in:
shahules786 2022-10-05 11:19:29 +05:30
commit c256b5478e
8 changed files with 34 additions and 22 deletions

View File

@ -1,3 +1,4 @@
from genericpath import isfile
import os import os
from types import MethodType from types import MethodType
import hydra import hydra
@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger from pytorch_lightning.loggers import MLFlowLogger
os.environ["HYDRA_FULL_ERROR"] = "1" os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID") JOB_ID = os.environ.get("SLURM_JOBID","0")
@hydra.main(config_path="train_config",config_name="config") @hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig): def main(config: DictConfig):
@ -55,8 +56,10 @@ def main(config: DictConfig):
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks) trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
trainer.fit(model) trainer.fit(model)
if os.path.exists("./model/"):
logger.experiment.log_artifact(logger.run_id,f"./model/.*") saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
if os.path.isfile(saved_location):
logger.experiment.log_artifact(logger.run_id,saved_location)

View File

@ -1,5 +1,5 @@
defaults: defaults:
- model : Demucs - model : WaveUnet
- dataset : Vctk - dataset : Vctk
- optimizer : Adam - optimizer : Adam
- hyperparameters : default - hyperparameters : default

View File

@ -3,12 +3,12 @@ name : vctk
root_dir : /scratch/c.sistc3/DS_10283_2791 root_dir : /scratch/c.sistc3/DS_10283_2791
duration : 1.0 duration : 1.0
sampling_rate: 16000 sampling_rate: 16000
batch_size: 8 batch_size: 64
files: files:
train_clean : clean_trainset_56spk_wav train_clean : clean_trainset_28spk_wav
test_clean : clean_testset_wav test_clean : clean_testset_wav
train_noisy : noisy_trainset_56spk_wav train_noisy : noisy_trainset_28spk_wav
test_noisy : noisy_testset_wav test_noisy : noisy_testset_wav

View File

@ -1,7 +1,6 @@
loss : mse loss : mse
metric : mae metric : mae
lr : 0.0001 lr : 0.0001
num_epochs : 100
ReduceLr_patience : 5 ReduceLr_patience : 5
ReduceLr_factor : 0.1 ReduceLr_factor : 0.1
min_lr : 0.000001 min_lr : 0.000001

View File

@ -22,8 +22,8 @@ limit_predict_batches: 1.0
limit_test_batches: 1.0 limit_test_batches: 1.0
limit_train_batches: 1.0 limit_train_batches: 1.0
limit_val_batches: 1.0 limit_val_batches: 1.0
log_every_n_steps: 10 log_every_n_steps: 1
max_epochs: 100 max_epochs: 10
max_steps: null max_steps: null
max_time: null max_time: null
min_epochs: 1 min_epochs: 1

View File

@ -27,6 +27,8 @@ class Inference:
if isinstance(audio, (np.ndarray, torch.Tensor)): if isinstance(audio, (np.ndarray, torch.Tensor)):
assert sr is not None, "Invalid sampling rate!" assert sr is not None, "Invalid sampling rate!"
if len(audio.shape) == 1:
audio = audio.reshape(1,-1)
if isinstance(audio, str): if isinstance(audio, str):
audio = Path(audio) audio = Path(audio)
@ -103,6 +105,8 @@ class Inference:
window = get_window(window=window, Nx=data.shape[-1]) window = get_window(window=window, Nx=data.shape[-1])
window = torch.from_numpy(window).to(data.device) window = torch.from_numpy(window).to(data.device)
data *= window data *= window
step_size = window_size//2 if step_size is None else step_size
data = data.permute(1, 2, 0) data = data.permute(1, 2, 0)
data = F.fold( data = F.fold(
@ -129,6 +133,9 @@ class Inference:
if isinstance(filename, str): if isinstance(filename, str):
filename = Path(filename) filename = Path(filename)
parent, name = filename.parent, "cleaned_"+filename.name
filename = parent/Path(name)
if filename.is_file(): if filename.is_file():
raise FileExistsError(f"file {filename} already exists") raise FileExistsError(f"file {filename} already exists")
else: else:

View File

@ -1,2 +1,3 @@
from enhancer.models.demucs import Demucs from enhancer.models.demucs import Demucs
from enhancer.models.waveunet import WaveUnet from enhancer.models.waveunet import WaveUnet
from enhancer.models.model import Model

View File

@ -191,8 +191,8 @@ class Model(pl.LightningModule):
map_location = torch.device(DEFAULT_DEVICE) map_location = torch.device(DEFAULT_DEVICE)
loaded_checkpoint = pl_load(model_path_pl,map_location) loaded_checkpoint = pl_load(model_path_pl,map_location)
module_name = loaded_checkpoint["architecture"]["module"] module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
class_name = loaded_checkpoint["architecture"]["class"] class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
module = import_module(module_name) module = import_module(module_name)
Klass = getattr(module, class_name) Klass = getattr(module, class_name)
@ -216,11 +216,12 @@ class Model(pl.LightningModule):
batch_predictions = [] batch_predictions = []
self.eval().to(self.device) self.eval().to(self.device)
for batch_id in range(batch.shape[0],batch_size): with torch.no_grad():
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device) for batch_id in range(0,batch.shape[0],batch_size):
prediction = self(batch_data) batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
batch_predictions.append(prediction) prediction = self(batch_data)
batch_predictions.append(prediction)
return torch.vstack(batch_predictions) return torch.vstack(batch_predictions)
def enhance( def enhance(
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
duration:Optional[int]=None, duration:Optional[int]=None,
step_size:Optional[int]=None,): step_size:Optional[int]=None,):
model_sampling_rate = self.model.hprams("sampling_rate") model_sampling_rate = self.hparams["sampling_rate"]
if duration is None: if duration is None:
duration = self.model.hparams("duration") duration = self.hparams["duration"]
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate) waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
waveform.to(self.device) waveform.to(self.device)
window_size = round(duration * model_sampling_rate) window_size = round(duration * model_sampling_rate)
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
Inference.write_output(waveform,audio,model_sampling_rate) Inference.write_output(waveform,audio,model_sampling_rate)
else: else:
return waveform waveform = Inference.prepare_output(waveform, model_sampling_rate,
audio, sampling_rate)
return waveform
@property @property
def valid_monitor(self): def valid_monitor(self):