"""Bundle THINGS stimulus images with the matching recorded responses into one pickle.

Usage::

    python <script> GEN_PATH MONKEY_NAME ROI_NAME

``GEN_PATH`` is concatenated directly with ``MONKEY_NAME``, so it must end with
a path separator (e.g. ``/data/``).

The script:

* reads the ``data`` array from ``<GEN_PATH><MONKEY_NAME>/THINGS_Parma_A29_<ROI_NAME>.mat``
  (opened with h5py, i.e. an HDF5 / MATLAB v7.3 file);
* uses the first ``N_VAL_ROWS`` rows as validation responses and the rest as
  training responses;
* loads the images under ``Things_s1/train/`` and ``Things_s1/test/`` with
  torchvision ``ImageFolder``;
* writes a dict with keys ``img_data``, ``val_img_data``, ``train_data`` and
  ``val_data`` to ``<GEN_PATH><MONKEY_NAME>/data_THINGS_array_<ROI_NAME>_v1.pkl``
  (pickle protocol 4).
"""
import os
import pickle
import sys

import h5py
import numpy as np
import torch
from torchvision import datasets, transforms

# Number of leading rows of the ``data`` array that belong to the validation images.
N_VAL_ROWS = 40

# When truthy, the training images are re-selected/re-ordered with the
# ``train_idx`` vector stored in DEBUG_IDXS_PATH.
DEBUG = 0
DEBUG_IDXS_PATH = ('/run/user/1000/gvfs/smb-share:server=vs03.herseninstituut.knaw.nl,'
                   'share=vs03-vandc-3/THINGS/Passive_Fixation/monkeyF/THINGS_normMUA.mat')


def _read_h5_dataset(path, key):
    """Return the HDF5 dataset ``key`` from the file at ``path`` as a NumPy array.

    Only the requested dataset is read, and the file is closed before returning.
    """
    with h5py.File(path, 'r') as f:
        return np.array(f[key])


def _first_batch(dataset, n_rows, label):
    """Load the first ``n_rows`` items of ``dataset``, in dataset order, as one image batch.

    Raises:
        ValueError: if ``dataset`` yields fewer than ``n_rows`` images, which
            would leave images and response rows misaligned.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=n_rows, shuffle=False)
    images, _ = next(iter(loader))  # class labels from ImageFolder are not needed
    if images.shape[0] < n_rows:
        raise ValueError(
            f'{label}: found only {images.shape[0]} images for {n_rows} response rows')
    return images


def main(argv=None):
    """Build the image/response pickle for one monkey and ROI.

    Args:
        argv: ``[GEN_PATH, MONKEY_NAME, ROI_NAME]``; defaults to ``sys.argv[1:]``.
    """
    argv = sys.argv[1:] if argv is None else argv
    if len(argv) < 3:
        sys.exit(f'usage: {os.path.basename(sys.argv[0])} GEN_PATH MONKEY_NAME ROI_NAME')
    gen_path, monkey_name, roi_name = argv[:3]

    base = gen_path + monkey_name
    filename = base + '/data_THINGS_array_' + roi_name + '_v1.pkl'
    # Alternative inputs used previously:
    #   data file base + '/THINGS_exportMUA_array' + roi_name + '.mat' with 'train_MUA'/'test_MUA'
    #   (or 'train'/'test' 1-based index vectors), images gen_path + 'THINGS_imgs/{train,val}/'.
    data_path = base + '/THINGS_Parma_A29_' + roi_name + '.mat'
    imgs_path = base + '/Things_s1/train/'
    val_imgs_path = base + '/Things_s1/test/'

    # NOTE(review): squeeze() also drops a singleton site axis; confirm every ROI has >1 site.
    roi_data = _read_h5_dataset(data_path, 'data').squeeze()
    val_data = roi_data[:N_VAL_ROWS]
    train_data = roi_data[N_VAL_ROWS:]

    # Resize(224) scales the shorter side to 224. NOTE(review): batching assumes all
    # images end up the same size (e.g. square inputs); confirm for this image set.
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    dataset = datasets.ImageFolder(imgs_path, transform=transform)
    dataset_val = datasets.ImageFolder(val_imgs_path, transform=transform)

    if DEBUG:
        # Convert the stored 1-based (MATLAB-style) indices to 0-based.
        idx_temp = _read_h5_dataset(DEBUG_IDXS_PATH, 'train_idx').astype(int).squeeze() - 1
        dataset = torch.utils.data.Subset(dataset, idx_temp)

    img_data = _first_batch(dataset, train_data.shape[0], 'train')
    val_img_data = _first_batch(dataset_val, val_data.shape[0], 'val')

    output = {'img_data': img_data, 'val_img_data': val_img_data,
              'train_data': train_data, 'val_data': val_data}
    with open(filename, 'wb') as f:
        pickle.dump(output, f, protocol=4)


if __name__ == '__main__':
    main()