commit c256b5478e
merge fix

@@ -1,3 +1,4 @@
+from genericpath import isfile
 import os
 from types import MethodType
 import hydra
@@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
 from pytorch_lightning.loggers import MLFlowLogger
 os.environ["HYDRA_FULL_ERROR"] = "1"
-JOB_ID = os.environ.get("SLURM_JOBID")
+JOB_ID = os.environ.get("SLURM_JOBID","0")
 
 @hydra.main(config_path="train_config",config_name="config")
 def main(config: DictConfig):
@@ -55,8 +56,10 @@ def main(config: DictConfig):
 
     trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
     trainer.fit(model)
-    if os.path.exists("./model/"):
-        logger.experiment.log_artifact(logger.run_id,f"./model/.*")
+    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
+    if os.path.isfile(saved_location):
+        logger.experiment.log_artifact(logger.run_id,saved_location)
 
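
Two of the changes above work together. MLflow's MlflowClient.log_artifact(run_id, local_path) uploads a single local file and does not expand glob patterns, so the old "./model/.*" argument could never resolve to a checkpoint; logging one concrete path, guarded by os.path.isfile, fixes that. The "0" fallback on SLURM_JOBID keeps that path well-formed off-cluster, as this small illustrative sketch (not repository code) shows:

    import os

    # Without the fallback, os.environ.get("SLURM_JOBID") returns None
    # outside a SLURM allocation, so the checkpoint name degrades to
    # "model_None.ckpt".
    job_id = os.environ.get("SLURM_JOBID", "0")
    print(f"model_{job_id}.ckpt")  # model_0.ckpt when SLURM_JOBID is unset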

@@ -1,5 +1,5 @@
 defaults:
-  - model : Demucs
+  - model : WaveUnet
   - dataset : Vctk
   - optimizer : Adam
   - hyperparameters : default
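
For context: each entry in a Hydra defaults list selects one YAML file from a config group directory of the same name, so this one-line change swaps the composed model config from Demucs to WaveUnet. A plausible layout for the train_config directory referenced by @hydra.main above (file names beyond those visible in the diff are assumptions):

    train_config/
        config.yaml              # the defaults list edited in this hunk
        model/
            Demucs.yaml
            WaveUnet.yaml        # selected by "- model : WaveUnet"
        dataset/
            Vctk.yaml
        optimizer/
            Adam.yaml
        hyperparameters/
            default.yaml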

@@ -3,12 +3,12 @@ name : vctk
 root_dir : /scratch/c.sistc3/DS_10283_2791
 duration : 1.0
 sampling_rate: 16000
-batch_size: 8
+batch_size: 64
 
 files:
-  train_clean : clean_trainset_56spk_wav
+  train_clean : clean_trainset_28spk_wav
   test_clean : clean_testset_wav
-  train_noisy : noisy_trainset_56spk_wav
+  train_noisy : noisy_trainset_28spk_wav
   test_noisy : noisy_testset_wav

@@ -1,7 +1,6 @@
 loss : mse
 metric : mae
 lr : 0.0001
-num_epochs : 100
 ReduceLr_patience : 5
 ReduceLr_factor : 0.1
 min_lr : 0.000001

@@ -22,8 +22,8 @@ limit_predict_batches: 1.0
 limit_test_batches: 1.0
 limit_train_batches: 1.0
 limit_val_batches: 1.0
-log_every_n_steps: 10
-max_epochs: 100
+log_every_n_steps: 1
+max_epochs: 10
 max_steps: null
 max_time: null
 min_epochs: 1

@@ -27,6 +27,8 @@ class Inference:
 
         if isinstance(audio, (np.ndarray, torch.Tensor)):
            assert sr is not None, "Invalid sampling rate!"
+            if len(audio.shape) == 1:
+                audio = audio.reshape(1,-1)
 
         if isinstance(audio, str):
             audio = Path(audio)
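
The two added lines normalize mono input: the downstream code evidently expects a (channels, samples) layout, so a 1-D signal is promoted to a single channel. A minimal sketch of the same behavior:

    import numpy as np

    audio = np.zeros(16000)           # one second of mono audio at 16 kHz
    if len(audio.shape) == 1:         # 1-D means no channel dimension yet
        audio = audio.reshape(1, -1)
    print(audio.shape)                # (1, 16000)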

@@ -103,6 +105,8 @@ class Inference:
         window = get_window(window=window, Nx=data.shape[-1])
         window = torch.from_numpy(window).to(data.device)
         data *= window
+        step_size = window_size//2 if step_size is None else step_size
+
 
         data = data.permute(1, 2, 0)
         data = F.fold(
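
The added default, step_size = window_size//2, gives 50% overlap between windows, and F.fold then sums the overlapping regions during reconstruction. A self-contained sketch of that overlap-add step, with shapes chosen purely for illustration:

    import torch
    import torch.nn.functional as F

    window_size, step_size, n_chunks = 4, 2, 3
    chunks = torch.ones(1, window_size, n_chunks)      # (batch, window, n_frames)
    length = step_size * (n_chunks - 1) + window_size  # reconstructed length: 8
    out = F.fold(chunks, output_size=(1, length),
                 kernel_size=(1, window_size), stride=(1, step_size))
    print(out.flatten())  # tensor([1., 1., 2., 2., 2., 2., 1., 1.]): overlaps summed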

@@ -129,6 +133,9 @@ class Inference:
 
         if isinstance(filename, str):
             filename = Path(filename)
+
+        parent, name = filename.parent, "cleaned_"+filename.name
+        filename = parent/Path(name)
         if filename.is_file():
             raise FileExistsError(f"file {filename} already exists")
         else:
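
The inserted lines derive the output name from the input path by prefixing "cleaned_" while keeping the directory, then refuse to overwrite an existing file. The path manipulation in isolation (the input path here is hypothetical):

    from pathlib import Path

    filename = Path("recordings/p232_001.wav")  # hypothetical input
    parent, name = filename.parent, "cleaned_" + filename.name
    filename = parent / Path(name)
    print(filename)                             # recordings/cleaned_p232_001.wav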

@@ -1,2 +1,3 @@
 from enhancer.models.demucs import Demucs
 from enhancer.models.waveunet import WaveUnet
+from enhancer.models.model import Model

@@ -191,8 +191,8 @@ class Model(pl.LightningModule):
         map_location = torch.device(DEFAULT_DEVICE)
 
         loaded_checkpoint = pl_load(model_path_pl,map_location)
-        module_name = loaded_checkpoint["architecture"]["module"]
-        class_name = loaded_checkpoint["architecture"]["class"]
+        module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
+        class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
         module = import_module(module_name)
         Klass = getattr(module, class_name)
 
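
The lookup now expects the custom metadata to sit under an "enhancer" key in the checkpoint dictionary, presumably written by the matching save hook. A runnable sketch of the dynamic class resolution, using stand-in module and class names rather than the real checkpoint contents:

    from importlib import import_module

    loaded_checkpoint = {  # stand-in for pl_load(...)
        "enhancer": {"architecture": {"module": "collections", "class": "Counter"}}
    }
    module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
    class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
    Klass = getattr(import_module(module_name), class_name)
    print(Klass)           # <class 'collections.Counter'>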

@@ -216,10 +216,11 @@ class Model(pl.LightningModule):
         batch_predictions = []
         self.eval().to(self.device)
 
-        for batch_id in range(batch.shape[0],batch_size):
-            batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
-            prediction = self(batch_data)
-            batch_predictions.append(prediction)
+        with torch.no_grad():
+            for batch_id in range(0,batch.shape[0],batch_size):
+                batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
+                prediction = self(batch_data)
+                batch_predictions.append(prediction)
 
         return torch.vstack(batch_predictions)
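
This hunk fixes two genuine bugs at once: range(batch.shape[0], batch_size) treats the tensor length as the start and the batch size as the stop, so the old loop body almost never ran, and inference previously built a gradient graph for no reason. The range fix in isolation:

    n_samples, batch_size = 10, 4
    print(list(range(n_samples, batch_size)))     # []        old form, loop never runs
    print(list(range(0, n_samples, batch_size)))  # [0, 4, 8] fixed form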

@@ -232,9 +233,9 @@ class Model(pl.LightningModule):
                 duration:Optional[int]=None,
                 step_size:Optional[int]=None,):
 
-        model_sampling_rate = self.model.hprams("sampling_rate")
+        model_sampling_rate = self.hparams["sampling_rate"]
         if duration is None:
-            duration = self.model.hparams("duration")
+            duration = self.hparams["duration"]
         waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
         waveform.to(self.device)
         window_size = round(duration * model_sampling_rate)
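
In PyTorch Lightning, save_hyperparameters() exposes constructor arguments through the mapping self.hparams, which is indexed, not called, so the misspelled, callable-style self.model.hprams("sampling_rate") would raise at runtime. A minimal sketch, assuming the hyperparameters were saved in __init__:

    import pytorch_lightning as pl

    class ToyModel(pl.LightningModule):
        def __init__(self, sampling_rate: int = 16000, duration: float = 1.0):
            super().__init__()
            self.save_hyperparameters()  # makes args available as self.hparams

    model = ToyModel()
    print(model.hparams["sampling_rate"])  # 16000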

@@ -246,8 +247,9 @@ class Model(pl.LightningModule):
             Inference.write_output(waveform,audio,model_sampling_rate)
 
         else:
+            waveform = Inference.prepare_output(waveform, model_sampling_rate,
+                                                audio, sampling_rate)
             return waveform
 
     @property
     def valid_monitor(self):