debug dataset
This commit is contained in:
parent
c06566c132
commit
9ed1b9d3f7
|
|
@ -3,7 +3,6 @@ from dataclasses import dataclass
|
||||||
import glob
|
import glob
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from typing_extensions import dataclass_transform
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
from torch.utils.data import IterableDataset, DataLoader, Dataset
|
from torch.utils.data import IterableDataset, DataLoader, Dataset
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -55,15 +54,24 @@ class TaskDataset(pl.LightningDataModule):
|
||||||
self.sampling_rate = sampling_rate
|
self.sampling_rate = sampling_rate
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
self.matching_function = matching_function
|
self.matching_function = matching_function
|
||||||
|
self._validation = []
|
||||||
|
|
||||||
def setup(self, stage: Optional[str] = None):
|
def setup(self, stage: Optional[str] = None):
|
||||||
|
|
||||||
if stage in ("fit",None):
|
if stage in ("fit",None):
|
||||||
|
|
||||||
fp = Fileprocessor.from_name(self.name,self.files.train_clean,self.files.train_noisy,self.matching_function)
|
train_clean = os.path.join(self.root_dir,self.files.train_clean)
|
||||||
|
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
|
||||||
|
fp = Fileprocessor.from_name(self.name,train_clean,
|
||||||
|
train_noisy,self.sampling_rate,
|
||||||
|
self.matching_function)
|
||||||
self.train_data = fp.prepare_matching_dict()
|
self.train_data = fp.prepare_matching_dict()
|
||||||
|
|
||||||
fp = Fileprocessor.from_name(self.name,self.files.test_clean,self.files.test_noisy,self.matching_function)
|
val_clean = os.path.join(self.root_dir,self.files.test_clean)
|
||||||
|
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
|
||||||
|
fp = Fileprocessor.from_name(self.name,val_clean,
|
||||||
|
val_noisy,self.sampling_rate,
|
||||||
|
self.matching_function)
|
||||||
val_data = fp.prepare_matching_dict()
|
val_data = fp.prepare_matching_dict()
|
||||||
|
|
||||||
for item in val_data:
|
for item in val_data:
|
||||||
|
|
@ -116,7 +124,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
|
|
||||||
def train__iter__(self):
|
def train__iter__(self):
|
||||||
|
|
||||||
rng = create_unique_rng(12) ##pass epoch number here
|
rng = create_unique_rng(self.model.current_epoch)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
||||||
|
|
@ -141,7 +149,7 @@ class EnhancerDataset(TaskDataset):
|
||||||
return {"clean": clean_segment,"noisy":noisy_segment}
|
return {"clean": clean_segment,"noisy":noisy_segment}
|
||||||
|
|
||||||
def train__len__(self):
|
def train__len__(self):
|
||||||
return math.ceil(sum([file["duration"] for file in self.valid_files])/self.duration)
|
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
|
||||||
|
|
||||||
def val__len__(self):
|
def val__len__(self):
|
||||||
return len(self._validation)
|
return len(self._validation)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue