Create the Market1501 class
Class-level attribute: dataset_dir = 'market1501'
Instantiation: data = Market1501(root = 'G:\data')
Instance attributes: the market1501 root directory, plus the paths to the training set, the query set, and the gallery.

def __init__(self, root='data', **kwargs):
    self.dataset_dir = osp.join(root, self.dataset_dir)
    self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
    self.query_dir = osp.join(self.dataset_dir, 'query')
    self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

After instantiation, check whether the corresponding train, query, and gallery directories exist; if any of them is missing, raise an error:

self._check_before_run()
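For reference, the body of this check (taken from the full listing below) simply tests each directory with osp.exists and raises a RuntimeError for the first one that is missing:

def _check_before_run(self):
    """Check if all files are available before going deeper"""
    if not osp.exists(self.dataset_dir):
        raise RuntimeError("{} is not available".format(self.dataset_dir))
    if not osp.exists(self.train_dir):
        raise RuntimeError("{} is not available".format(self.train_dir))
    if not osp.exists(self.query_dir):
        raise RuntimeError("{} is not available".format(self.query_dir))
    if not osp.exists(self.gallery_dir):
        raise RuntimeError("{} is not available".format(self.gallery_dir))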

The key method is _process_dir. It takes the path of a folder of images plus a flag relabel. Why do we need this flag? It works like a switch: when processing the training set we turn it on so that the raw person IDs are remapped to consecutive labels starting from 0, which is what a classification loss expects; when processing the query and gallery sets we leave it off and keep the original IDs, since those are only used for matching at evaluation time.
First, all image file paths are read with glob and stored in the img_paths list, and a set() is used to collect the person IDs. The advantage of a set is that it de-duplicates automatically, so it directly tells us how many distinct classes there are.
Looping over every path and extracting its ID, we end up with a set holding all the class labels, e.g. (cat, dog, frog, …).
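To make this concrete, here is how the regex in _process_dir pulls the person ID and the camera ID out of a filename (the filenames below are just made-up examples in the usual Market1501 naming style, not files read from disk):

import re

pattern = re.compile(r'([-\d]+)_c(\d)')

# Market1501 names images roughly as '<pid>_c<camid>s<seq>_<frame>_<bbox>.jpg'
pid, camid = map(int, pattern.search('0002_c1s1_000451_03.jpg').groups())
print(pid, camid)   # 2 1

# distractor/junk images carry pid -1 and are skipped in _process_dir
pid, _ = map(int, pattern.search('-1_c3s2_000100_00.jpg').groups())
print(pid)          # -1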
Next, map each class name to its corresponding label:

pid2label = {pid:label for label, pid in enumerate(pid_container)}

This yields a dictionary pid2label whose keys are the original class names (cat, …) and whose values are the corresponding new labels (0, 1, 2, …).
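A tiny illustration with made-up IDs (the real pid_container holds the 751 training identities):

# hypothetical raw person IDs collected from the training filenames
pid_container = {2, 7, 11}
pid2label = {pid: label for label, pid in enumerate(pid_container)}
# maps each raw ID to a consecutive label in 0..2 (the exact pairing
# depends on set iteration order); with relabel=True these consecutive
# labels are what the classification loss will see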
Then we build the dataset list, where each element is a tuple (img_path, pid, camid).
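A sketch of what one loop iteration produces for a single (made-up) training image, with relabel turned on:

import re

pattern = re.compile(r'([-\d]+)_c(\d)')
pid2label = {2: 0, 7: 1, 11: 2}          # hypothetical mapping from the step above
dataset = []

img_path = 'data/market1501/bounding_box_train/0002_c1s1_000451_03.jpg'
pid, camid = map(int, pattern.search(img_path).groups())   # -> 2, 1
camid -= 1                                                 # camera IDs become 0-based
dataset.append((img_path, pid2label[pid], camid))
print(dataset)   # [('data/market1501/bounding_box_train/0002_c1s1_000451_03.jpg', 0, 0)]

Putting all the pieces together, the complete data_manager.py: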

# -*- encoding: utf-8 -*-
"""
@File    : data_manager.py
@Time    : 2021-05-07 11:25
@Author  : XD
@Email   : gudianpai@qq.com
@Software: PyCharm
"""
import os
import os.path as osp
import re
import glob

from utils import mkdir_if_missing, write_json, read_json
from IPython import embed


class Market1501(object):
    """Market1501

    Reference:
    Zheng et al. Scalable Person Re-identification: A Benchmark. ICCV 2015.
    URL: http://www.liangzheng.org/Project/project_reid.html

    Dataset statistics:
    # identities: 1501 (+1 for background)
    # images: 12936 (train) + 3368 (query) + 15913 (gallery)
    """
    dataset_dir = 'market1501'

    def __init__(self, root='data', **kwargs):
        self.dataset_dir = osp.join(root, self.dataset_dir)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
        self._check_before_run()

        # each returned list holds (img_path, pid, camid) tuples;
        # only the training set is relabeled to consecutive IDs
        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)
        num_total_pids = num_train_pids + num_query_pids
        num_total_imgs = num_train_imgs + num_query_imgs + num_gallery_imgs

        print("=> Market1501 loaded")
        print("Dataset statistics:")
        print("  ------------------------------")
        print("  subset   | # ids | # images")
        print("  ------------------------------")
        print("  train    | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print("  query    | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print("  gallery  | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print("  ------------------------------")
        print("  total    | {:5d} | {:8d}".format(num_total_pids, num_total_imgs))
        print("  ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        """Check if all files are available before going deeper"""
        if not osp.exists(self.dataset_dir):
            raise RuntimeError("{} is not available".format(self.dataset_dir))
        if not osp.exists(self.train_dir):
            raise RuntimeError("{} is not available".format(self.train_dir))
        if not osp.exists(self.query_dir):
            raise RuntimeError("{} is not available".format(self.query_dir))
        if not osp.exists(self.gallery_dir):
            raise RuntimeError("{} is not available".format(self.gallery_dir))

    def _process_dir(self, dir_path, relabel=False):
        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
        pattern = re.compile(r'([-\d]+)_c(\d)')

        # first pass: collect all person IDs (the set removes duplicates)
        pid_container = set()
        for img_path in img_paths:
            pid, _ = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue  # junk images are ignored
            pid_container.add(pid)
        pid2label = {pid: label for label, pid in enumerate(pid_container)}

        # second pass: build the (img_path, pid, camid) tuples
        dataset = []
        for img_path in img_paths:
            pid, camid = map(int, pattern.search(img_path).groups())
            if pid == -1:
                continue
            assert 0 <= pid <= 1501
            assert 1 <= camid <= 6
            camid -= 1  # camera index starts from 0
            if relabel:
                pid = pid2label[pid]
            dataset.append((img_path, pid, camid))

        num_pids = len(pid_container)
        num_imgs = len(img_paths)
        return dataset, num_pids, num_imgs


if __name__ == '__main__':
    data = Market1501(root = 'G:\data')
D:\ANACONDA\envs\pytorch_gpu\python.exe G:/图像检索文章/深度哈希/度量学习/proj_Reid/util/data_manager.py
=> Market1501 loaded
Dataset statistics:
  ------------------------------
  subset   | # ids | # images
  ------------------------------
  train    |   751 |    12936
  query    |   750 |     3368
  gallery  |   751 |    19732
  ------------------------------
  total    |  1501 |    36036
  ------------------------------

Process finished with exit code 0
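With the loader in place, downstream code can consume the three splits directly. A minimal usage sketch (nothing here is part of data_manager.py itself):

data = Market1501(root = 'G:\data')

# each split is a plain Python list of (img_path, pid, camid) tuples
img_path, pid, camid = data.train[0]
print(len(data.train), len(data.query), len(data.gallery))
print(data.num_train_pids)   # 751 -> the number of classes a classifier head would need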