debug dataset

This commit is contained in:
shahules786 2022-09-12 10:54:52 +05:30
parent c06566c132
commit 9ed1b9d3f7
1 changed files with 13 additions and 5 deletions

View File

@ -3,7 +3,6 @@ from dataclasses import dataclass
import glob
import math
import os
from typing_extensions import dataclass_transform
import pytorch_lightning as pl
from torch.utils.data import IterableDataset, DataLoader, Dataset
import torch.nn.functional as F
@ -55,15 +54,24 @@ class TaskDataset(pl.LightningDataModule):
self.sampling_rate = sampling_rate
self.batch_size = batch_size
self.matching_function = matching_function
self._validation = []
def setup(self, stage: Optional[str] = None):
if stage in ("fit",None):
fp = Fileprocessor.from_name(self.name,self.files.train_clean,self.files.train_noisy,self.matching_function)
train_clean = os.path.join(self.root_dir,self.files.train_clean)
train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
fp = Fileprocessor.from_name(self.name,train_clean,
train_noisy,self.sampling_rate,
self.matching_function)
self.train_data = fp.prepare_matching_dict()
fp = Fileprocessor.from_name(self.name,self.files.test_clean,self.files.test_noisy,self.matching_function)
val_clean = os.path.join(self.root_dir,self.files.test_clean)
val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
fp = Fileprocessor.from_name(self.name,val_clean,
val_noisy,self.sampling_rate,
self.matching_function)
val_data = fp.prepare_matching_dict()
for item in val_data:
@ -116,7 +124,7 @@ class EnhancerDataset(TaskDataset):
def train__iter__(self):
rng = create_unique_rng(12) ##pass epoch number here
rng = create_unique_rng(self.model.current_epoch)
while True:
@ -141,7 +149,7 @@ class EnhancerDataset(TaskDataset):
return {"clean": clean_segment,"noisy":noisy_segment}
def train__len__(self):
return math.ceil(sum([file["duration"] for file in self.valid_files])/self.duration)
return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
def val__len__(self):
return len(self._validation)