merge fix
This commit is contained in:
commit
c256b5478e
|
|
@ -1,3 +1,4 @@
|
|||
from genericpath import isfile
|
||||
import os
|
||||
from types import MethodType
|
||||
import hydra
|
||||
|
|
@ -7,7 +8,7 @@ from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.loggers import MLFlowLogger
|
||||
os.environ["HYDRA_FULL_ERROR"] = "1"
|
||||
JOB_ID = os.environ.get("SLURM_JOBID")
|
||||
JOB_ID = os.environ.get("SLURM_JOBID","0")
|
||||
|
||||
@hydra.main(config_path="train_config",config_name="config")
|
||||
def main(config: DictConfig):
|
||||
|
|
@ -55,8 +56,10 @@ def main(config: DictConfig):
|
|||
|
||||
trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
|
||||
trainer.fit(model)
|
||||
if os.path.exists("./model/"):
|
||||
logger.experiment.log_artifact(logger.run_id,f"./model/.*")
|
||||
|
||||
saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
|
||||
if os.path.isfile(saved_location):
|
||||
logger.experiment.log_artifact(logger.run_id,saved_location)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
defaults:
|
||||
- model : Demucs
|
||||
- model : WaveUnet
|
||||
- dataset : Vctk
|
||||
- optimizer : Adam
|
||||
- hyperparameters : default
|
||||
|
|
|
|||
|
|
@ -3,12 +3,12 @@ name : vctk
|
|||
root_dir : /scratch/c.sistc3/DS_10283_2791
|
||||
duration : 1.0
|
||||
sampling_rate: 16000
|
||||
batch_size: 8
|
||||
batch_size: 64
|
||||
|
||||
files:
|
||||
train_clean : clean_trainset_56spk_wav
|
||||
train_clean : clean_trainset_28spk_wav
|
||||
test_clean : clean_testset_wav
|
||||
train_noisy : noisy_trainset_56spk_wav
|
||||
train_noisy : noisy_trainset_28spk_wav
|
||||
test_noisy : noisy_testset_wav
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
loss : mse
|
||||
metric : mae
|
||||
lr : 0.0001
|
||||
num_epochs : 100
|
||||
ReduceLr_patience : 5
|
||||
ReduceLr_factor : 0.1
|
||||
min_lr : 0.000001
|
||||
|
|
|
|||
|
|
@ -22,8 +22,8 @@ limit_predict_batches: 1.0
|
|||
limit_test_batches: 1.0
|
||||
limit_train_batches: 1.0
|
||||
limit_val_batches: 1.0
|
||||
log_every_n_steps: 10
|
||||
max_epochs: 100
|
||||
log_every_n_steps: 1
|
||||
max_epochs: 10
|
||||
max_steps: null
|
||||
max_time: null
|
||||
min_epochs: 1
|
||||
|
|
|
|||
|
|
@ -27,6 +27,8 @@ class Inference:
|
|||
|
||||
if isinstance(audio, (np.ndarray, torch.Tensor)):
|
||||
assert sr is not None, "Invalid sampling rate!"
|
||||
if len(audio.shape) == 1:
|
||||
audio = audio.reshape(1,-1)
|
||||
|
||||
if isinstance(audio, str):
|
||||
audio = Path(audio)
|
||||
|
|
@ -103,6 +105,8 @@ class Inference:
|
|||
window = get_window(window=window, Nx=data.shape[-1])
|
||||
window = torch.from_numpy(window).to(data.device)
|
||||
data *= window
|
||||
step_size = window_size//2 if step_size is None else step_size
|
||||
|
||||
|
||||
data = data.permute(1, 2, 0)
|
||||
data = F.fold(
|
||||
|
|
@ -129,6 +133,9 @@ class Inference:
|
|||
|
||||
if isinstance(filename, str):
|
||||
filename = Path(filename)
|
||||
|
||||
parent, name = filename.parent, "cleaned_"+filename.name
|
||||
filename = parent/Path(name)
|
||||
if filename.is_file():
|
||||
raise FileExistsError(f"file {filename} already exists")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -1,2 +1,3 @@
|
|||
from enhancer.models.demucs import Demucs
|
||||
from enhancer.models.waveunet import WaveUnet
|
||||
from enhancer.models.model import Model
|
||||
|
|
@ -191,8 +191,8 @@ class Model(pl.LightningModule):
|
|||
map_location = torch.device(DEFAULT_DEVICE)
|
||||
|
||||
loaded_checkpoint = pl_load(model_path_pl,map_location)
|
||||
module_name = loaded_checkpoint["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["architecture"]["class"]
|
||||
module_name = loaded_checkpoint["enhancer"]["architecture"]["module"]
|
||||
class_name = loaded_checkpoint["enhancer"]["architecture"]["class"]
|
||||
module = import_module(module_name)
|
||||
Klass = getattr(module, class_name)
|
||||
|
||||
|
|
@ -216,7 +216,8 @@ class Model(pl.LightningModule):
|
|||
batch_predictions = []
|
||||
self.eval().to(self.device)
|
||||
|
||||
for batch_id in range(batch.shape[0],batch_size):
|
||||
with torch.no_grad():
|
||||
for batch_id in range(0,batch.shape[0],batch_size):
|
||||
batch_data = batch[batch_id:batch_id+batch_size,:,:].to(self.device)
|
||||
prediction = self(batch_data)
|
||||
batch_predictions.append(prediction)
|
||||
|
|
@ -232,9 +233,9 @@ class Model(pl.LightningModule):
|
|||
duration:Optional[int]=None,
|
||||
step_size:Optional[int]=None,):
|
||||
|
||||
model_sampling_rate = self.model.hprams("sampling_rate")
|
||||
model_sampling_rate = self.hparams["sampling_rate"]
|
||||
if duration is None:
|
||||
duration = self.model.hparams("duration")
|
||||
duration = self.hparams["duration"]
|
||||
waveform = Inference.read_input(audio,sampling_rate,model_sampling_rate)
|
||||
waveform.to(self.device)
|
||||
window_size = round(duration * model_sampling_rate)
|
||||
|
|
@ -246,8 +247,9 @@ class Model(pl.LightningModule):
|
|||
Inference.write_output(waveform,audio,model_sampling_rate)
|
||||
|
||||
else:
|
||||
waveform = Inference.prepare_output(waveform, model_sampling_rate,
|
||||
audio, sampling_rate)
|
||||
return waveform
|
||||
|
||||
@property
|
||||
def valid_monitor(self):
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue