Merge pull request #13 from shahules786/dev-reformat

flake8 cli/ tests/
Shahul ES authored 2022-10-05 15:56:01 +05:30 · committed by GitHub
commit 1b30d7f3ed
18 changed files with 180 additions and 142 deletions


@@ -1,67 +0,0 @@
from genericpath import isfile
import os
from types import MethodType
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID","0")

@hydra.main(config_path="train_config",config_name="config")
def main(config: DictConfig):

    callbacks = []
    logger = MLFlowLogger(experiment_name=config.mlflow.experiment_name,
                          run_name=config.mlflow.run_name, tags={"JOB_ID":JOB_ID})

    parameters = config.hyperparameters
    dataset = instantiate(config.dataset)
    model = instantiate(config.model,dataset=dataset,lr=parameters.get("lr"),
                        loss=parameters.get("loss"), metric = parameters.get("metric"))

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",filename=f"model_{JOB_ID}",monitor="val_loss",verbose=True,
        mode=direction,every_n_epochs=1
    )
    callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode=direction,
        min_delta=0.0,
        patience=parameters.get("EarlyStopping_patience",10),
        strict=True,
        verbose=False,
    )
    callbacks.append(early_stopping)

    def configure_optimizer(self):
        optimizer = instantiate(config.optimizer,lr=parameters.get("lr"),parameters=self.parameters())
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor",0.1),
            verbose=True,
            min_lr=parameters.get("min_lr",1e-6),
            patience=parameters.get("ReduceLr_patience",3)
        )
        return {"optimizer":optimizer, "lr_scheduler":scheduler}

    model.configure_parameters = MethodType(configure_optimizer,model)

    trainer = instantiate(config.trainer,logger=logger,callbacks=callbacks)
    trainer.fit(model)

    saved_location = os.path.join(trainer.default_root_dir,"model",f"model_{JOB_ID}.ckpt")
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id,saved_location)

if __name__=="__main__":
    main()

enhancer/cli/train.py (Normal file, 84 additions)

@@ -0,0 +1,84 @@
import os
from types import MethodType

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import MLFlowLogger

os.environ["HYDRA_FULL_ERROR"] = "1"
JOB_ID = os.environ.get("SLURM_JOBID", "0")


@hydra.main(config_path="train_config", config_name="config")
def main(config: DictConfig):

    callbacks = []
    logger = MLFlowLogger(
        experiment_name=config.mlflow.experiment_name,
        run_name=config.mlflow.run_name,
        tags={"JOB_ID": JOB_ID},
    )

    parameters = config.hyperparameters
    dataset = instantiate(config.dataset)
    model = instantiate(
        config.model,
        dataset=dataset,
        lr=parameters.get("lr"),
        loss=parameters.get("loss"),
        metric=parameters.get("metric"),
    )

    direction = model.valid_monitor
    checkpoint = ModelCheckpoint(
        dirpath="./model",
        filename=f"model_{JOB_ID}",
        monitor="val_loss",
        verbose=True,
        mode=direction,
        every_n_epochs=1,
    )
    callbacks.append(checkpoint)
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode=direction,
        min_delta=0.0,
        patience=parameters.get("EarlyStopping_patience", 10),
        strict=True,
        verbose=False,
    )
    callbacks.append(early_stopping)

    def configure_optimizer(self):
        optimizer = instantiate(
            config.optimizer,
            lr=parameters.get("lr"),
            parameters=self.parameters(),
        )
        scheduler = ReduceLROnPlateau(
            optimizer=optimizer,
            mode=direction,
            factor=parameters.get("ReduceLr_factor", 0.1),
            verbose=True,
            min_lr=parameters.get("min_lr", 1e-6),
            patience=parameters.get("ReduceLr_patience", 3),
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    model.configure_parameters = MethodType(configure_optimizer, model)

    trainer = instantiate(config.trainer, logger=logger, callbacks=callbacks)
    trainer.fit(model)

    saved_location = os.path.join(
        trainer.default_root_dir, "model", f"model_{JOB_ID}.ckpt"
    )
    if os.path.isfile(saved_location):
        logger.experiment.log_artifact(logger.run_id, saved_location)


if __name__ == "__main__":
    main()
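
A note on the pattern above: train.py swaps in the optimizer at runtime by binding a closure to the model instance with types.MethodType, so the Hydra config and hyperparameters stay in scope without subclassing the model. A minimal, self-contained sketch of that binding (the Linear model and Adam choice here are illustrative, not part of this PR):

import torch
from types import MethodType

model = torch.nn.Linear(8, 1)
lr = 1e-3  # stands in for parameters.get("lr") read from the Hydra config

def configure_optimizer(self):
    # `self` is the bound model instance; `lr` is captured from the
    # enclosing scope, mirroring how train.py captures config/parameters.
    return {"optimizer": torch.optim.Adam(self.parameters(), lr=lr)}

# Bind the function as a method on this one instance only.
model.configure_optimizer = MethodType(configure_optimizer, model)
print(model.configure_optimizer())  # {'optimizer': Adam (...)}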


@@ -0,0 +1,13 @@
_target_: enhancer.data.dataset.EnhancerDataset
name : vctk
root_dir : /Users/shahules/Myprojects/enhancer/datasets/vctk
duration : 1.0
sampling_rate: 16000
batch_size: 64
num_workers : 0
files:
train_clean : clean_testset_wav
test_clean : clean_testset_wav
train_noisy : noisy_testset_wav
test_noisy : noisy_testset_wav
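
For context, this is the dataset config consumed by instantiate(config.dataset) in train.py above: Hydra resolves _target_ to a dotted class path and passes the remaining keys as keyword arguments. Roughly (a sketch; the vctk.yaml path is a stand-in and the EnhancerDataset signature is inferred from the keys above):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("vctk.yaml")  # hypothetical path to the file above
# Give or take Hydra's handling of the nested `files` block, this reduces to:
#   EnhancerDataset(name="vctk", root_dir="...", duration=1.0,
#                   sampling_rate=16000, batch_size=64, num_workers=0,
#                   files={...})
dataset = instantiate(cfg)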


@@ -6,26 +6,28 @@ from enhancer.loss import mean_absolute_error, mean_squared_error

loss_functions = [mean_absolute_error(), mean_squared_error()]


def check_loss_shapes_compatibility(loss_fun):
    batch_size = 4
    shape = (1, 1000)
    loss_fun(torch.rand(batch_size, *shape), torch.rand(batch_size, *shape))

    with pytest.raises(TypeError):
        loss_fun(torch.rand(4, *shape), torch.rand(6, *shape))


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_input_shapes(loss):
    check_loss_shapes_compatibility(loss)


@pytest.mark.parametrize("loss", loss_functions)
def test_loss_output_type(loss):
    batch_size = 4
    prediction, target = torch.rand(batch_size, 1, 1000), torch.rand(
        batch_size, 1, 1000
    )
    loss_value = loss(prediction, target)
    assert isinstance(loss_value.item(), float)
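
These tests pin down a contract rather than a value: each loss must accept equal-shaped batches and raise TypeError when batch sizes disagree. A minimal stand-in that satisfies the same contract (illustrative only; the real implementations live in enhancer.loss):

import torch

class MeanAbsoluteErrorSketch:
    # Illustrative stand-in for enhancer.loss.mean_absolute_error.
    def __call__(self, prediction, target):
        if prediction.shape != target.shape:
            raise TypeError(
                f"shape mismatch: {tuple(prediction.shape)} vs {tuple(target.shape)}"
            )
        return torch.mean(torch.abs(prediction - target))

loss = MeanAbsoluteErrorSketch()
value = loss(torch.rand(4, 1, 1000), torch.rand(4, 1, 1000))
assert isinstance(value.item(), float)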


@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset

@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = Demucs()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        model = Demucs(num_channels=channels, dataset=dataset, loss=loss)
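
The parametrization above injects a fixture through pytest.lazy_fixture, which comes from the pytest-lazy-fixture plugin; a plain @pytest.mark.parametrize cannot reference fixtures directly. A standalone illustration (the fixture and values here are made up; requires the plugin installed):

import pytest

@pytest.fixture
def numbers():
    return [1, 2, 3]

@pytest.mark.parametrize("data,total", [(pytest.lazy_fixture("numbers"), 6)])
def test_sum(data, total):
    assert sum(data) == total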


@@ -10,37 +10,35 @@ from enhancer.data.dataset import EnhancerDataset

@pytest.fixture
def vctk_dataset():
    root_dir = "tests/data/vctk"
    files = Files(
        train_clean="clean_testset_wav",
        train_noisy="noisy_testset_wav",
        test_clean="clean_testset_wav",
        test_noisy="noisy_testset_wav",
    )
    dataset = EnhancerDataset(name="vctk", root_dir=root_dir, files=files)
    return dataset


@pytest.mark.parametrize("batch_size,samples", [(1, 1000)])
def test_forward(batch_size, samples):
    model = WaveUnet()
    model.eval()

    data = torch.rand(batch_size, 1, samples, requires_grad=False)
    with torch.no_grad():
        _ = model(data)

    data = torch.rand(batch_size, 2, samples, requires_grad=False)
    with torch.no_grad():
        with pytest.raises(TypeError):
            _ = model(data)


@pytest.mark.parametrize(
    "dataset,channels,loss",
    [(pytest.lazy_fixture("vctk_dataset"), 1, ["mae", "mse"])],
)
def test_demucs_init(dataset, channels, loss):
    with torch.no_grad():
        model = WaveUnet(num_channels=channels, dataset=dataset, loss=loss)


@@ -4,22 +4,26 @@ import torch

from enhancer.inference import Inference


@pytest.mark.parametrize(
    "audio",
    ["tests/data/vctk/clean_testset_wav/p257_166.wav", torch.rand(1, 2, 48000)],
)
def test_read_input(audio):
    read_audio = Inference.read_input(audio, 48000, 16000)
    assert isinstance(read_audio, torch.Tensor)
    assert read_audio.shape[0] == 1


def test_batchify():
    rand = torch.rand(1, 1000)
    batched_rand = Inference.batchify(rand, window_size=100, step_size=100)
    assert batched_rand.shape[0] == 12


def test_aggregate():
    rand = torch.rand(12, 1, 100)
    agg_rand = Inference.aggreagate(
        data=rand, window_size=100, total_frames=1000, step_size=100
    )
    assert agg_rand.shape[-1] == 1000
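
A quick sanity check on the numbers in test_batchify: a bare sliding window over 1000 samples with window and step of 100 yields 10 chunks, so the expected 12 suggests Inference.batchify pads the signal at the edges before chunking (an inference from the assertion, not confirmed by this diff):

import torch

rand = torch.rand(1, 1000)
# torch.Tensor.unfold(dim, size, step): plain windowing, no padding
chunks = rand.unfold(-1, 100, 100)
print(chunks.shape)  # torch.Size([1, 10, 100]) -> 10 windows, not 12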


@@ -7,40 +7,46 @@ from enhancer.utils.io import Audio

from enhancer.utils.config import Files
from enhancer.data.fileprocessor import Fileprocessor


def test_io_channel():
    input_audio = np.random.rand(2, 32000)
    audio = Audio(mono=True, return_tensor=False)
    output_audio = audio(input_audio)
    assert output_audio.shape[0] == 1


def test_io_resampling():
    input_audio = np.random.rand(1, 32000)
    resampled_audio = Audio.resample_audio(input_audio, 16000, 8000)

    input_audio = torch.rand(1, 32000)
    resampled_audio_pt = Audio.resample_audio(input_audio, 16000, 8000)

    assert resampled_audio.shape[1] == resampled_audio_pt.size(1) == 16000


def test_fileprocessor_vctk():
    fp = Fileprocessor.from_name(
        "vctk",
        "tests/data/vctk/clean_testset_wav",
        "tests/data/vctk/noisy_testset_wav",
        48000,
    )
    matching_dict = fp.prepare_matching_dict()
    assert len(matching_dict) == 2


@pytest.mark.parametrize("dataset_name", ["vctk", "dns-2020"])
def test_fileprocessor_names(dataset_name):
    fp = Fileprocessor.from_name(dataset_name, "clean_dir", "noisy_dir", 16000)
    assert hasattr(fp.matching_function, "__call__")


def test_fileprocessor_invaliname():
    with pytest.raises(ValueError):
        fp = Fileprocessor.from_name(
            "undefined", "clean_dir", "noisy_dir", 16000
        ).prepare_matching_dict()
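
Taken together, the last two tests pin down Fileprocessor.from_name as a name-based dispatch: known dataset names resolve to a callable matching function, and unknown names surface a ValueError by the time the matching dict is prepared. A plausible sketch of that registry pattern (all names below are assumptions drawn from the tests, not the actual implementation):

from typing import Callable, Dict

def match_vctk(clean_dir: str, noisy_dir: str) -> dict:
    return {}  # pair clean/noisy files one-to-one by filename

def match_dns2020(clean_dir: str, noisy_dir: str) -> dict:
    return {}  # pair files via the DNS challenge naming convention

MATCHERS: Dict[str, Callable] = {"vctk": match_vctk, "dns-2020": match_dns2020}

def matching_function_for(name: str) -> Callable:
    # Unknown names fall through to ValueError, as the last test expects.
    try:
        return MATCHERS[name.lower()]
    except KeyError:
        raise ValueError(f"unknown dataset name: {name!r}")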