Merge branch 'dev' of https://github.com/shahules786/enhancer into dev-hawk
This commit is contained in:
commit
94d70c4ddf
|
|
@ -132,11 +132,14 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
for item in data:
|
for item in data:
|
||||||
clean, noisy, total_dur = item.values()
|
clean, noisy, total_dur = item.values()
|
||||||
if total_dur < self.duration:
|
if total_dur < self.duration:
|
||||||
continue
|
metadata.append(({"clean": clean, "noisy": noisy}, 0.0))
|
||||||
|
else:
|
||||||
num_segments = round(total_dur / self.duration)
|
num_segments = round(total_dur / self.duration)
|
||||||
for index in range(num_segments):
|
for index in range(num_segments):
|
||||||
start_time = index * self.duration
|
start_time = index * self.duration
|
||||||
metadata.append(({"clean": clean, "noisy": noisy}, start_time))
|
metadata.append(
|
||||||
|
({"clean": clean, "noisy": noisy}, start_time)
|
||||||
|
)
|
||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
def train_dataloader(self):
|
def train_dataloader(self):
|
||||||
|
|
@ -195,6 +198,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
files: Files,
|
files: Files,
|
||||||
valid_minutes=5.0,
|
valid_minutes=5.0,
|
||||||
duration=1.0,
|
duration=1.0,
|
||||||
|
stride=0.5,
|
||||||
sampling_rate=48000,
|
sampling_rate=48000,
|
||||||
matching_function=None,
|
matching_function=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
|
|
@ -217,6 +221,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
self.files = files
|
self.files = files
|
||||||
self.duration = max(1.0, duration)
|
self.duration = max(1.0, duration)
|
||||||
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
self.audio = Audio(self.sampling_rate, mono=True, return_tensor=True)
|
||||||
|
self.stride = stride or duration
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
|
|
||||||
|
|
@ -234,9 +239,22 @@ class EnhancerDataset(TaskDataset):
|
||||||
weights=[file["duration"] for file in self.train_data],
|
weights=[file["duration"] for file in self.train_data],
|
||||||
)
|
)
|
||||||
file_duration = file_dict["duration"]
|
file_duration = file_dict["duration"]
|
||||||
start_time = round(rng.uniform(0, file_duration - self.duration), 2)
|
num_segments = self.get_num_segments(
|
||||||
data = self.prepare_segment(file_dict, start_time)
|
file_duration, self.duration, self.stride
|
||||||
yield data
|
)
|
||||||
|
for index in range(0, num_segments):
|
||||||
|
start_time = index * self.stride
|
||||||
|
yield self.prepare_segment(file_dict, start_time)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_num_segments(file_duration, duration, stride):
|
||||||
|
|
||||||
|
if file_duration < duration:
|
||||||
|
num_segments = 1
|
||||||
|
else:
|
||||||
|
num_segments = math.ceil((file_duration - duration) / stride) + 1
|
||||||
|
|
||||||
|
return num_segments
|
||||||
|
|
||||||
def val__getitem__(self, idx):
|
def val__getitem__(self, idx):
|
||||||
return self.prepare_segment(*self._validation[idx])
|
return self.prepare_segment(*self._validation[idx])
|
||||||
|
|
@ -273,8 +291,16 @@ class EnhancerDataset(TaskDataset):
|
||||||
return {"clean": clean_segment, "noisy": noisy_segment}
|
return {"clean": clean_segment, "noisy": noisy_segment}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
|
|
||||||
return math.ceil(
|
return math.ceil(
|
||||||
sum([file["duration"] for file in self.train_data]) / self.duration
|
sum(
|
||||||
|
[
|
||||||
|
self.get_num_segments(
|
||||||
|
file["duration"], self.duration, self.stride
|
||||||
|
)
|
||||||
|
for file in self.train_data
|
||||||
|
]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
|
|
|
||||||
|
|
@ -133,10 +133,12 @@ class Demucs(Model):
|
||||||
num_channels: int = 1,
|
num_channels: int = 1,
|
||||||
resample: int = 4,
|
resample: int = 4,
|
||||||
sampling_rate=16000,
|
sampling_rate=16000,
|
||||||
|
normalize=True,
|
||||||
lr: float = 1e-3,
|
lr: float = 1e-3,
|
||||||
dataset: Optional[EnhancerDataset] = None,
|
dataset: Optional[EnhancerDataset] = None,
|
||||||
loss: Union[str, List] = "mse",
|
loss: Union[str, List] = "mse",
|
||||||
metric: Union[str, List] = "mse",
|
metric: Union[str, List] = "mse",
|
||||||
|
floor=1e-3,
|
||||||
):
|
):
|
||||||
duration = (
|
duration = (
|
||||||
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
dataset.duration if isinstance(dataset, EnhancerDataset) else None
|
||||||
|
|
@ -161,6 +163,8 @@ class Demucs(Model):
|
||||||
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
lstm = merge_dict(self.LSTM_DEFAULTS, lstm)
|
||||||
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
self.save_hyperparameters("encoder_decoder", "lstm", "resample")
|
||||||
hidden = encoder_decoder["initial_output_channels"]
|
hidden = encoder_decoder["initial_output_channels"]
|
||||||
|
self.normalize = normalize
|
||||||
|
self.floor = floor
|
||||||
self.encoder = nn.ModuleList()
|
self.encoder = nn.ModuleList()
|
||||||
self.decoder = nn.ModuleList()
|
self.decoder = nn.ModuleList()
|
||||||
|
|
||||||
|
|
@ -204,7 +208,10 @@ class Demucs(Model):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
f"Demucs can only process mono channel audio, input has {waveform.size(1)} channels"
|
||||||
)
|
)
|
||||||
|
if self.normalize:
|
||||||
|
waveform = waveform.mean(dim=1, keepdim=True)
|
||||||
|
std = waveform.std(dim=-1, keepdim=True)
|
||||||
|
waveform = waveform / (self.floor + std)
|
||||||
length = waveform.shape[-1]
|
length = waveform.shape[-1]
|
||||||
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
x = F.pad(waveform, (0, self.get_padding_length(length) - length))
|
||||||
if self.hparams.resample > 1:
|
if self.hparams.resample > 1:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue