# extract_data.py (2.4 KB)
import os
import pickle
import sys

import h5py
import numpy as np
import torch
from torchvision import datasets, transforms
  8. gen_path=sys.argv[1]
  9. monkey_name=sys.argv[2]
  10. roi_name=sys.argv[3]
  11. filename = gen_path + monkey_name + '/data_THINGS_array_'+ roi_name +'_v1.pkl'
  12. #data_path = gen_path + monkey_name +'/THINGS_exportMUA_array'+ roi_name + '.mat'
  13. data_path = gen_path + monkey_name +'/THINGS_Parma_A29_'+ roi_name + '.mat'
  14. imgs_path = gen_path + monkey_name +'/Things_s1/train/'
  15. val_imgs_path = gen_path + monkey_name + '/Things_s1/test/'
  16. #imgs_path = gen_path + 'THINGS_imgs/train/'
  17. #val_imgs_path = gen_path + 'THINGS_imgs/val/'
  18. data_dict = {}
  19. f = h5py.File(data_path,'r')
  20. for k, v in f.items():
  21. data_dict[k] = np.array(v)
  22. roi_data = data_dict['data'].squeeze()
  23. #train_idx = data_dict['train'].squeeze()-1
  24. #val_idx = data_dict['test'].squeeze()-1
  25. val_data = roi_data[:40]
  26. train_data = roi_data[40:]
  27. #val_data = data_dict['test_MUA'].squeeze()
  28. #train_data = data_dict['train_MUA'].squeeze()
  29. n_neurons = train_data.shape[1]
  30. del data_dict
  31. transform = transforms.Compose([transforms.Resize(224),
  32. transforms.ToTensor(),
  33. transforms.Normalize(mean=[0.485, 0.456, 0.406],
  34. std=[0.229, 0.224, 0.225])])
  35. dataset = datasets.ImageFolder(imgs_path, transform=transform)
  36. dataset_val = datasets.ImageFolder(val_imgs_path, transform=transform)
  37. debug = 0
  38. if debug:
  39. idxs_path = '/run/user/1000/gvfs/smb-share:server=vs03.herseninstituut.knaw.nl,share=vs03-vandc-3/THINGS/Passive_Fixation/monkeyF/THINGS_normMUA.mat'
  40. data_dict = {}
  41. f = h5py.File(idxs_path,'r')
  42. for k, v in f.items():
  43. data_dict[k] = np.array(v)
  44. idx_temp = data_dict['train_idx'].astype(int).squeeze() - 1
  45. temp_subset = torch.utils.data.Subset(dataset, idx_temp)
  46. dataloader = torch.utils.data.DataLoader(temp_subset, batch_size=train_data.shape[0], shuffle=False)
  47. else:
  48. dataloader = torch.utils.data.DataLoader(dataset, batch_size=train_data.shape[0], shuffle=False)
  49. img_data, junk = next(iter(dataloader))
  50. val_dataloader = torch.utils.data.DataLoader(dataset_val, batch_size=val_data.shape[0], shuffle=False)
  51. val_img_data, junk = next(iter(val_dataloader))
  52. output = {'img_data':img_data, 'val_img_data':val_img_data,
  53. 'train_data':train_data, 'val_data':val_data}
  54. f = open(filename,"wb")
  55. pickle.dump(output,f,protocol=4)
  56. f.close()