debug dataset
This commit is contained in:
parent
c06566c132
commit
9ed1b9d3f7
|
|
@ -3,7 +3,6 @@ from dataclasses import dataclass
|
|||
import glob
|
||||
import math
|
||||
import os
|
||||
from typing_extensions import dataclass_transform
|
||||
import pytorch_lightning as pl
|
||||
from torch.utils.data import IterableDataset, DataLoader, Dataset
|
||||
import torch.nn.functional as F
|
||||
|
|
@ -55,15 +54,24 @@ class TaskDataset(pl.LightningDataModule):
|
|||
self.sampling_rate = sampling_rate
|
||||
self.batch_size = batch_size
|
||||
self.matching_function = matching_function
|
||||
self._validation = []
|
||||
|
||||
def setup(self, stage: Optional[str] = None):
|
||||
|
||||
if stage in ("fit",None):
|
||||
|
||||
fp = Fileprocessor.from_name(self.name,self.files.train_clean,self.files.train_noisy,self.matching_function)
|
||||
train_clean = os.path.join(self.root_dir,self.files.train_clean)
|
||||
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
|
||||
fp = Fileprocessor.from_name(self.name,train_clean,
|
||||
train_noisy,self.sampling_rate,
|
||||
self.matching_function)
|
||||
self.train_data = fp.prepare_matching_dict()
|
||||
|
||||
fp = Fileprocessor.from_name(self.name,self.files.test_clean,self.files.test_noisy,self.matching_function)
|
||||
val_clean = os.path.join(self.root_dir,self.files.test_clean)
|
||||
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
|
||||
fp = Fileprocessor.from_name(self.name,val_clean,
|
||||
val_noisy,self.sampling_rate,
|
||||
self.matching_function)
|
||||
val_data = fp.prepare_matching_dict()
|
||||
|
||||
for item in val_data:
|
||||
|
|
@ -116,7 +124,7 @@ class EnhancerDataset(TaskDataset):
|
|||
|
||||
def train__iter__(self):
|
||||
|
||||
rng = create_unique_rng(12) ##pass epoch number here
|
||||
rng = create_unique_rng(self.model.current_epoch)
|
||||
|
||||
while True:
|
||||
|
||||
|
|
@ -141,7 +149,7 @@ class EnhancerDataset(TaskDataset):
|
|||
return {"clean": clean_segment,"noisy":noisy_segment}
|
||||
|
||||
def train__len__(self):
|
||||
return math.ceil(sum([file["duration"] for file in self.valid_files])/self.duration)
|
||||
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
|
||||
|
||||
def val__len__(self):
|
||||
return len(self._validation)
|
||||
|
|
|
|||
Loading…
Reference in New Issue