import os
from typing import Callable, Optional

from PIL import Image
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms


class CustomCocoDataset(Dataset):
    """COCO-captions dataset yielding two resizes of one crop plus a caption.

    Every item is a ``dict`` with three keys:

    * ``jpg``  -- the cropped image resized with size ``512``;
    * ``hint`` -- the *same* crop resized with size ``448``;
    * ``txt``  -- all captions of the image joined with single spaces.
    """

    def __init__(
        self,
        json_file: str,
        img_folder: str,
        common_transform: Optional[Callable] = None,
    ) -> None:
        """Load the annotation index and remember where the images live.

        Args:
            json_file: Path to the COCO annotation JSON handed to ``COCO``.
            img_folder: Directory that each image's ``file_name`` is joined to.
            common_transform: Optional callable applied to both the ``jpg``
                and the ``hint`` image after resizing.
        """
        self.coco = COCO(json_file)
        self.img_folder = img_folder
        # Index only the image ids that appear as keys of ``imgToAnns``.
        self.ids = list(self.coco.imgToAnns.keys())
        self.common_transform = common_transform

    def __len__(self) -> int:
        """Return the number of indexed image ids."""
        return len(self.ids)

    def __getitem__(self, index: int) -> dict:
        """Build the sample for the image at position ``index``.

        Args:
            index: Position into ``self.ids``.

        Returns:
            ``dict(jpg=..., txt=..., hint=...)`` as described on the class.

        Raises:
            IndexError: If ``index`` is outside ``self.ids``.
        """
        img_id = self.ids[index]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_folder, img_info['file_name'])

        # ``convert`` returns an independent image, so the file handle can be
        # released as soon as the pixel data has been decoded.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert('RGB')

        # Sample one crop box and reuse it for both outputs so that ``jpg``
        # and ``hint`` always show exactly the same region.
        i, j, h, w = transforms.RandomResizedCrop.get_params(
            image, scale=(0.95, 1.0), ratio=(1.0, 1.0))
        cropped_image = transforms.functional.crop(image, i, j, h, w)

        bicubic = transforms.InterpolationMode.BICUBIC
        jpg_image = transforms.functional.resize(
            cropped_image, 512, interpolation=bicubic)
        hint_image = transforms.functional.resize(
            cropped_image, 448, interpolation=bicubic)

        if self.common_transform is not None:
            jpg_image = self.common_transform(jpg_image)
            hint_image = self.common_transform(hint_image)

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        # Flatten embedded newlines so the combined caption is a single line.
        captions = [ann['caption'].replace('\n', ' ') for ann in anns]
        combined_caption = ' '.join(captions)

        return dict(jpg=jpg_image, txt=combined_caption, hint=hint_image)


def main(
    json_file: str = '/home/t2vg-a100-G4-1/projects/dataset/annotations/captions_train2017.json',
    img_folder: str = '/home/t2vg-a100-G4-1/projects/dataset/train2017',
) -> None:
    """Print the min and max value of every ``jpg`` tensor in the dataset.

    This is a sanity check on the value range produced by the transform
    pipeline; it walks the whole dataset one sample at a time, in order.

    Args:
        json_file: Path to the COCO captions annotation file.
        img_folder: Directory containing the images the annotations refer to.
    """
    common_transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = CustomCocoDataset(
        json_file=json_file,
        img_folder=img_folder,
        common_transform=common_transform,
    )

    # batch_size=1 and no shuffling: inspect samples individually, in order.
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

    for batch in dataloader:
        jpg_image = batch['jpg']

        print(f'JPG Image Min Value: {jpg_image.min().item()}')
        print(f'JPG Image Max Value: {jpg_image.max().item()}')


# Run the sanity check only when executed as a script, not when imported.
if __name__ == "__main__":
    main()