From 9ed1b9d3f717718e885e3bf05d1c48ab5a36bb4e Mon Sep 17 00:00:00 2001
From: shahules786
Date: Mon, 12 Sep 2022 10:54:52 +0530
Subject: [PATCH] debug dataset

---
 enhancer/data/dataset.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/enhancer/data/dataset.py b/enhancer/data/dataset.py
index adcdea1..aa98219 100644
--- a/enhancer/data/dataset.py
+++ b/enhancer/data/dataset.py
@@ -3,7 +3,6 @@ from dataclasses import dataclass
 import glob
 import math
 import os
-from typing_extensions import dataclass_transform
 import pytorch_lightning as pl
 from torch.utils.data import IterableDataset, DataLoader, Dataset
 import torch.nn.functional as F
@@ -55,15 +54,24 @@ class TaskDataset(pl.LightningDataModule):
         self.sampling_rate = sampling_rate
         self.batch_size = batch_size
         self.matching_function = matching_function
+        self._validation = []
 
     def setup(self, stage: Optional[str] = None):
 
         if stage in ("fit",None):
 
-            fp = Fileprocessor.from_name(self.name,self.files.train_clean,self.files.train_noisy,self.matching_function)
+            train_clean = os.path.join(self.root_dir,self.files.train_clean)
+            train_noisy = os.path.join(self.root_dir,self.files.train_noisy)
+            fp = Fileprocessor.from_name(self.name,train_clean,
+                                         train_noisy,self.sampling_rate,
+                                         self.matching_function)
             self.train_data = fp.prepare_matching_dict()
 
-            fp = Fileprocessor.from_name(self.name,self.files.test_clean,self.files.test_noisy,self.matching_function)
+            val_clean = os.path.join(self.root_dir,self.files.test_clean)
+            val_noisy = os.path.join(self.root_dir,self.files.test_noisy)
+            fp = Fileprocessor.from_name(self.name,val_clean,
+                                         val_noisy,self.sampling_rate,
+                                         self.matching_function)
             val_data = fp.prepare_matching_dict()
 
             for item in val_data:
@@ -116,7 +124,7 @@ class EnhancerDataset(TaskDataset):
 
     def train__iter__(self):
 
-        rng = create_unique_rng(12) ##pass epoch number here
+        rng = create_unique_rng(self.model.current_epoch)
 
         while True:
@@ -141,7 +149,7 @@ class EnhancerDataset(TaskDataset):
         return {"clean": clean_segment,"noisy":noisy_segment}
 
     def train__len__(self):
-        return math.ceil(sum([file["duration"] for file in self.valid_files])/self.duration)
+        return math.ceil(sum([file["duration"] for file in self.train_data])/self.duration)
 
     def val__len__(self):
         return len(self._validation)