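"""Assemble paired stimulus images and neural responses for one THINGS ROI.

Reads multi-unit activity from a MATLAB v7.3 (HDF5) file, loads the matching
train/test stimulus images with torchvision, and dumps everything into a
single pickle file.

Usage: python <this_script>.py GEN_PATH MONKEY_NAME ROI_NAME
"""
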
import os
import sys
import pickle

import h5py
import numpy as np
import torch
from torchvision import datasets, transforms

# Command-line arguments: base data directory, monkey identifier, ROI name.
gen_path = sys.argv[1]
monkey_name = sys.argv[2]
roi_name = sys.argv[3]

# Output pickle and input paths (os.path.join tolerates a missing trailing
# slash on gen_path, which plain string concatenation did not).
filename = os.path.join(gen_path, monkey_name, 'data_THINGS_array_' + roi_name + '_v1.pkl')
#data_path = os.path.join(gen_path, monkey_name, 'THINGS_exportMUA_array' + roi_name + '.mat')
data_path = os.path.join(gen_path, monkey_name, 'THINGS_Parma_A29_' + roi_name + '.mat')
imgs_path = os.path.join(gen_path, monkey_name, 'Things_s1/train/')
val_imgs_path = os.path.join(gen_path, monkey_name, 'Things_s1/test/')
#imgs_path = os.path.join(gen_path, 'THINGS_imgs/train/')
#val_imgs_path = os.path.join(gen_path, 'THINGS_imgs/val/')

# Load the HDF5-backed .mat file into a plain dict of numpy arrays.
data_dict = {}
with h5py.File(data_path, 'r') as f:
    for k, v in f.items():
        data_dict[k] = np.array(v)

# The first 40 stimuli are the held-out test set; the remainder are training data.
roi_data = data_dict['data'].squeeze()
#train_idx = data_dict['train'].squeeze() - 1
#val_idx = data_dict['test'].squeeze() - 1
val_data = roi_data[:40]
train_data = roi_data[40:]
#val_data = data_dict['test_MUA'].squeeze()
#train_data = data_dict['train_MUA'].squeeze()
n_neurons = train_data.shape[1]
del data_dict

# Standard ImageNet preprocessing. Note that Resize(224) only fixes the shorter
# side; the single-batch loading below assumes all images end up the same size.
transform = transforms.Compose([transforms.Resize(224),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])])
dataset = datasets.ImageFolder(imgs_path, transform=transform)
dataset_val = datasets.ImageFolder(val_imgs_path, transform=transform)

debug = False
if debug:
    # Reorder the training images with the 1-based MATLAB train_idx stored in
    # the reference file (hence the -1 to make the indices 0-based).
    idxs_path = '/run/user/1000/gvfs/smb-share:server=vs03.herseninstituut.knaw.nl,share=vs03-vandc-3/THINGS/Passive_Fixation/monkeyF/THINGS_normMUA.mat'
    data_dict = {}
    with h5py.File(idxs_path, 'r') as f:
        for k, v in f.items():
            data_dict[k] = np.array(v)
    idx_temp = data_dict['train_idx'].astype(int).squeeze() - 1
    temp_subset = torch.utils.data.Subset(dataset, idx_temp)
    dataloader = torch.utils.data.DataLoader(temp_subset, batch_size=train_data.shape[0], shuffle=False)
else:
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_data.shape[0], shuffle=False)

# Load every image in a single batch so row order matches the response
# matrices (ImageFolder sorts files; shuffle=False preserves that order).
img_data, _ = next(iter(dataloader))
val_dataloader = torch.utils.data.DataLoader(dataset_val, batch_size=val_data.shape[0], shuffle=False)
val_img_data, _ = next(iter(val_dataloader))

output = {'img_data': img_data, 'val_img_data': val_img_data,
          'train_data': train_data, 'val_data': val_data}

# Protocol 4 supports pickles larger than 4 GB (the stacked image tensors can be big).
with open(filename, 'wb') as f:
    pickle.dump(output, f, protocol=4)
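
# Reload sketch: the pickle holds the four arrays/tensors keyed as above, e.g.
#   with open(filename, 'rb') as f:
#       output = pickle.load(f)
#   output['img_data']   -> (n_train_images, 3, H, W) image tensor
#   output['train_data'] -> (n_train_images, n_neurons) response matrix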