|
阅读本文前,请阅读相关说明,帮助您了解本系列文章
mmsegmentation的安装
mmsegmentation仓库地址:GitHub - open-mmlab/mmsegmentation: OpenMMLab Semantic Segmentation Toolbox and Benchmark.
具体安装细节可参考文档:Prerequisites - MMSegmentation 0.29.0 documentation
# My installation
# Create the environment
conda create -n mmseg python=3.7
conda activate mmseg
# Install the PyTorch build that matches your own setup
pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# Install the mmsegmentation-related packages
pip install -U openmim
mim install mmcv-full
# The mmsegmentation repository can be downloaded manually beforehand
cd mmsegmentation
pip install -v -e .
# 安装完成后可使用文档中的测试命令验证安装是否成功
数据集准备
数据集准备可参考文档:Prepare datasets (本系列文章以ADE20K为例)
ADE20K数据集下载地址:http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip
将下载好的数据集组织成文档中的目录组织形式
Pycharm Debug设置
- mmsegmentation中的train.py脚本在tools文件夹下,如果直接使用Pycharm对train脚本进行debug,会将tools作为默认的工作目录,从而导致某些路径找不到或者某些包无法导入等问题,所以在debug train脚本时需要将工作目录设置为mmsegmentation的根目录
- 在train.py中有一个必传参数config,用于指定具体的配置文件。本文以FCN为例,为此将其设置成了默认值。本文使用的配置文件为configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py
def parse_args():
parser = argparse.ArgumentParser(description='Train a segmentor')
parser.add_argument('--config', default='configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py', help='train config file path') # 必须指定的参数, 指定具体的配置文件
关于配置文件的命名方式可参考文档:Tutorial 1: Learn about Configs
源码阅读
通过调试tools文件夹下的train.py脚本分析模型整个前向过程,本文尽量按照代码的执行顺序进行梳理
tools/train.py
主要用于传递和设置各种参数。与其相关联的文件为mmseg/apis/train.py,其中的train_segmentor函数在tools/train.py中被调用。
数据读取
配置文件configs/fcn/fcn_r50-d8_512x512_80k_ade20k.py中的配置信息如下,有关数据集的配置继承自_base_/datasets/ade20k.py
_base_ = [
'../_base_/models/fcn_r50-d8.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))为此,有关数据集的配置信息存在于_base_/datasets/ade20k.py中,如下
# Dataset settings for ADE20K.
dataset_type = 'ADE20KDataset'
data_root = 'data/ade/ADEChallengeData2016'
# Normalization parameters handed to the Normalize step of both pipelines.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (512, 512)

# Operations applied, in order, to every training sample.
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

# Operations applied to validation / test samples.
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

# Per-split dataset definitions; val and test both read the validation folders.
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/training',
        ann_dir='annotations/training',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='images/validation',
        ann_dir='annotations/validation',
        pipeline=test_pipeline))
在数据集配置文件中,使用的数据集为ADE20K数据集,与之相关的类在mmseg/datasets/ade.py
该类为ADE20KDataset, 其继承自CustomDataset,此类位于mmseg/datasets/custom.py
CustomDataset继承自Pytorch的Dataset类.
继承体系: ADE20KDataset --> CustomDataset --> Dataset
__getitem__方法进行了一系列的数据处理, 该函数在CustomDataset类中
class CustomDataset(Dataset):
    """Excerpt of the segmentation dataset base class.

    Only the parts relevant to reading one training sample are shown; helpers
    such as ``load_annotations``, ``get_ann_info``, ``prepare_test_img`` and
    ``get_classes_and_palette`` are called here but defined outside this
    excerpt.

    Args:
        pipeline: list of transform configs, wrapped in ``Compose``.
        img_dir: image directory; joined with ``data_root`` when relative.
        img_suffix: suffix of image files.
        ann_dir: annotation directory; joined with ``data_root`` when relative.
        seg_map_suffix: suffix of segmentation map files.
        split: optional split file; joined with ``data_root`` when relative.
        data_root: optional root prepended to the relative paths above.
        test_mode: when True ``__getitem__`` returns test samples and
            ``CLASSES`` must be set.
        ignore_index: stored on the instance as ``self.ignore_index``.
        reduce_zero_label: stored on the instance as
            ``self.reduce_zero_label``.
        classes: forwarded to ``get_classes_and_palette``.
        palette: forwarded to ``get_classes_and_palette``.
        gt_seg_map_loader_cfg: keyword arguments for ``LoadAnnotations``;
            ``None`` means default arguments.
        file_client_args: arguments for ``mmcv.FileClient.infer_client``.
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 pipeline,
                 img_dir,
                 img_suffix='.jpg',
                 ann_dir=None,
                 seg_map_suffix='.png',
                 split=None,
                 data_root=None,
                 test_mode=False,
                 ignore_index=255,
                 reduce_zero_label=False,
                 classes=None,
                 palette=None,
                 gt_seg_map_loader_cfg=None,
                 file_client_args=dict(backend='disk')):
        # Wrap the config in a Compose object; `pipeline` is the series of
        # operations from train_pipeline in the config file that process
        # the data.
        self.pipeline = Compose(pipeline)
        self.img_dir = img_dir
        self.img_suffix = img_suffix
        self.ann_dir = ann_dir
        self.seg_map_suffix = seg_map_suffix
        self.split = split
        self.data_root = data_root
        self.test_mode = test_mode
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.label_map = None
        self.CLASSES, self.PALETTE = self.get_classes_and_palette(
            classes, palette)
        self.gt_seg_map_loader = LoadAnnotations(
        ) if gt_seg_map_loader_cfg is None else LoadAnnotations(
            **gt_seg_map_loader_cfg)

        self.file_client_args = file_client_args
        self.file_client = mmcv.FileClient.infer_client(self.file_client_args)

        if test_mode:
            assert self.CLASSES is not None, \
                '`cls.CLASSES` or `classes` should be specified when testing'

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.img_dir):
                self.img_dir = osp.join(self.data_root, self.img_dir)
            if not (self.ann_dir is None or osp.isabs(self.ann_dir)):
                self.ann_dir = osp.join(self.data_root, self.ann_dir)
            if not (self.split is None or osp.isabs(self.split)):
                self.split = osp.join(self.data_root, self.split)

        # load annotations
        self.img_infos = self.load_annotations(self.img_dir, self.img_suffix,
                                               self.ann_dir,
                                               self.seg_map_suffix, self.split)

    def __len__(self):
        """Return the number of loaded image infos."""
        return len(self.img_infos)

    def __getitem__(self, idx):
        """Return the test or training sample at ``idx`` per ``test_mode``."""
        if self.test_mode:
            return self.prepare_test_img(idx)
        else:  # taken in this walkthrough (training)
            return self.prepare_train_img(idx)

    def prepare_train_img(self, idx):
        """Build the results dict for sample ``idx`` and run the pipeline."""
        img_info = self.img_infos[idx]  # image info
        ann_info = self.get_ann_info(idx)  # annotation info of the image
        # Initial results dict; it keeps being updated by later steps.
        results = dict(img_info=img_info, ann_info=ann_info)
        self.pre_pipeline(results)  # add a series of entries to the dict
        # Calls Compose.__call__ to process the data.
        return self.pipeline(results)

    def pre_pipeline(self, results):
        """Add the fields the pipeline expects to ``results`` (in place)."""
        results['seg_fields'] = []
        results['img_prefix'] = self.img_dir
        results['seg_prefix'] = self.ann_dir
        if self.custom_classes:  # not taken in this walkthrough
            results['label_map'] = self.label_map
Compose类位于mmseg/datasets/pipelines/compose.py,如下
class Compose(object):
    """Chain several data transforms and apply them in order.

    Args:
        transforms: a sequence whose items are either config dicts (built
            through ``build_from_cfg`` with the ``PIPELINES`` registry) or
            callables.

    Raises:
        TypeError: if an item is neither a dict nor callable.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        # List; each element is one operation of train_pipeline from the
        # dataset config file.
        self.transforms = []
        for transform in transforms:
            if isinstance(transform, dict):
                transform = build_from_cfg(transform, PIPELINES)
                self.transforms.append(transform)
            elif callable(transform):
                self.transforms.append(transform)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Pass ``data`` through every transform.

        Returns:
            The transformed data, or ``None`` as soon as any transform
            returns ``None``.
        """
        for t in self.transforms:  # iterate over every data operation
            # Run operation t. In this walkthrough these are, in order:
            # LoadImageFromFile, LoadAnnotations, Resize, RandomCrop,
            # RandomFlip, PhotoMetricDistortion, Normalize, Pad,
            # DefaultFormatBundle, Collect.
            # Each class's __call__ is invoked and updates the data.
            data = t(data)
            if data is None:  # usually not taken
                return None
        return data


# The LoadImageFromFile class is located in mmseg/datasets/pipelines/loading.py, as follows
class LoadImageFromFile(object):
    """Load an image from file and initialise its meta entries in ``results``.

    Args:
        to_float32: when True the decoded image is cast to ``np.float32``.
        color_type: ``flag`` passed to ``mmcv.imfrombytes``.
        file_client_args: keyword arguments for ``mmcv.FileClient``; a copy
            is stored on the instance.
        imdecode_backend: ``backend`` passed to ``mmcv.imfrombytes``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        self.to_float32 = to_float32
        self.color_type = color_type
        self.file_client_args = file_client_args.copy()
        # The client is created lazily on the first __call__.
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Read the image referenced by ``results`` and return the updated dict."""
        if self.file_client is None:  # taken in this walkthrough
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('img_prefix') is not None:  # taken in this walkthrough
            # Assemble the path of the image.
            filename = osp.join(results['img_prefix'],
                                results['img_info']['filename'])
        else:
            filename = results['img_info']['filename']
        img_bytes = self.file_client.get(filename)
        # Decode the image; per the article: ndarray of shape [H, W, C].
        img = mmcv.imfrombytes(
            img_bytes, flag=self.color_type, backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename  # image path
        results['ori_filename'] = results['img_info']['filename']  # image name
        results['img'] = img  # the loaded image
        results['img_shape'] = img.shape  # image size, updated later
        results['ori_shape'] = img.shape  # original image size
        # Set initial values for default meta_keys; they are updated later.
        results['pad_shape'] = img.shape  # image size, updated later
        results['scale_factor'] = 1.0  # updated later
        num_channels = 1 if len(img.shape) < 3 else img.shape[2]
        results['img_norm_cfg'] = dict(  # norm parameters, updated later
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results


# The LoadAnnotations class is located in mmseg/datasets/pipelines/loading.py, as follows
class LoadAnnotations(object):
    """Load the semantic segmentation annotation referenced by ``results``.

    Args:
        reduce_zero_label: when True, label 0 becomes 255 and every other
            label is shifted down by one.
        file_client_args: keyword arguments for ``mmcv.FileClient``; a copy
            is stored on the instance.
        imdecode_backend: ``backend`` passed to ``mmcv.imfrombytes``.
    """

    def __init__(self,
                 reduce_zero_label=False,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow'):
        self.reduce_zero_label = reduce_zero_label
        self.file_client_args = file_client_args.copy()
        # The client is created lazily on the first __call__.
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    def __call__(self, results):
        """Read the annotation, remap labels if requested, and update ``results``."""
        if self.file_client is None:  # taken in this walkthrough
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:  # taken in this walkthrough
            # Assemble the path of the annotation.
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']
        img_bytes = self.file_client.get(filename)
        # Decode the annotation; per the article: ndarray of shape [H, W].
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)
        # modify if custom classes
        if results.get('label_map', None) is not None:  # not taken in this walkthrough
            # Add deep copy to solve bug of repeatedly
            # replace `gt_semantic_seg`, which is reported in
            # https://github.com/open-mmlab/mmsegmentation/pull/1445/
            gt_semantic_seg_copy = gt_semantic_seg.copy()
            for old_id, new_id in results['label_map'].items():
                gt_semantic_seg[gt_semantic_seg_copy == old_id] = new_id
        # reduce zero_label
        # This flag is set differently per dataset. The article notes it is
        # True for ADE20K because 0 denotes background, which is not one of
        # the 150 ADE20K classes.
        if self.reduce_zero_label:
            # avoid using underflow conversion: map 0 to 255 first, shift
            # everything down by one, then restore 254 back to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        results['gt_semantic_seg'] = gt_semantic_seg  # add the annotation
        results['seg_fields'].append('gt_semantic_seg')
        return results


# The Resize class is located in mmseg/datasets/pipelines/transforms.py, as follows
class Resize(object):
    """Resize the image and every segmentation field in ``results``.

    ``random_sample`` and ``random_select`` are referenced by
    ``_random_scale`` but are not part of this excerpt.

    Args:
        img_scale: a ``(w, h)``-style tuple or a list of such tuples, or None.
        multiscale_mode: ``'value'`` or ``'range'``; only checked when
            ``ratio_range`` is None.
        ratio_range: ``(min_ratio, max_ratio)`` used to randomly rescale
            ``img_scale``.
        keep_ratio: use ``mmcv.imrescale`` (True) or ``mmcv.imresize`` (False).
        min_size: optional lower bound for the short side of the target scale.
    """

    def __init__(self,
                 img_scale=None,
                 multiscale_mode='range',
                 ratio_range=None,
                 keep_ratio=True,
                 min_size=None):
        if img_scale is None:
            self.img_scale = None
        else:
            if isinstance(img_scale, list):
                self.img_scale = img_scale
            else:
                self.img_scale = [img_scale]
            assert mmcv.is_list_of(self.img_scale, tuple)

        if ratio_range is not None:
            # mode 1: given img_scale=None and a range of image ratio
            # mode 2: given a scale and a range of image ratio
            assert self.img_scale is None or len(self.img_scale) == 1
        else:
            # mode 3 and 4: given multiple scales or a range of scales
            assert multiscale_mode in ['value', 'range']

        self.multiscale_mode = multiscale_mode
        self.ratio_range = ratio_range
        self.keep_ratio = keep_ratio
        self.min_size = min_size

    @staticmethod
    def random_sample_ratio(img_scale, ratio_range):
        """Scale ``img_scale`` by a ratio drawn uniformly from ``ratio_range``.

        Returns:
            ``(scale, None)`` where ``scale`` is the rescaled 2-tuple of ints.
        """
        assert isinstance(img_scale, tuple) and len(img_scale) == 2
        min_ratio, max_ratio = ratio_range
        assert min_ratio <= max_ratio
        # Randomly pick a ratio inside ratio_range.
        ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
        # Target scale range.
        scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
        return scale, None

    def _random_scale(self, results):
        """Choose the target scale and store it in ``results`` (in place)."""
        if self.ratio_range is not None:  # taken in this walkthrough
            if self.img_scale is None:
                h, w = results['img'].shape[:2]
                scale, scale_idx = self.random_sample_ratio((w, h),
                                                            self.ratio_range)
            else:  # taken in this walkthrough
                scale, scale_idx = self.random_sample_ratio(
                    self.img_scale[0], self.ratio_range)
        elif len(self.img_scale) == 1:
            scale, scale_idx = self.img_scale[0], 0
        elif self.multiscale_mode == 'range':
            scale, scale_idx = self.random_sample(self.img_scale)
        elif self.multiscale_mode == 'value':
            scale, scale_idx = self.random_select(self.img_scale)
        else:
            raise NotImplementedError

        results['scale'] = scale  # add the scale info
        results['scale_idx'] = scale_idx  # None in this walkthrough

    def _resize_img(self, results):
        """Resize images with ``results['scale']``."""
        if self.keep_ratio:  # taken in this walkthrough
            if self.min_size is not None:  # not taken in this walkthrough
                # TODO: Now 'min_size' is an 'int' which means the minimum
                # shape of images is (min_size, min_size, 3). 'min_size'
                # with tuple type will be supported, i.e. the width and
                # height are not equal.
                if min(results['scale']) < self.min_size:
                    new_short = self.min_size
                else:
                    new_short = min(results['scale'])

                h, w = results['img'].shape[:2]
                if h > w:
                    new_h, new_w = new_short * h / w, new_short
                else:
                    new_h, new_w = new_short, new_short * w / h
                results['scale'] = (new_h, new_w)

            # The image is rescaled to be as large as possible while still
            # fitting inside scale.
            img, scale_factor = mmcv.imrescale(
                results['img'], results['scale'], return_scale=True)
            # the w_scale and h_scale has minor difference
            # a real fix should be done in the mmcv.imrescale in the future
            new_h, new_w = img.shape[:2]  # height, width after resizing
            h, w = results['img'].shape[:2]  # height, width before resizing
            w_scale = new_w / w  # scaling factor of the width
            h_scale = new_h / h  # scaling factor of the height
        else:
            img, w_scale, h_scale = mmcv.imresize(
                results['img'], results['scale'], return_scale=True)
        scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
                                dtype=np.float32)
        results['img'] = img  # update the image
        results['img_shape'] = img.shape  # image size after resizing
        results['pad_shape'] = img.shape  # padded size, updated later
        results['scale_factor'] = scale_factor  # ndarray of shape [4,]
        results['keep_ratio'] = self.keep_ratio

    def _resize_seg(self, results):
        """Resize semantic segmentation map with ``results['scale']``."""
        for key in results.get('seg_fields', []):
            if self.keep_ratio:  # taken in this walkthrough
                # The annotation is rescaled to be as large as possible
                # while still fitting inside scale.
                gt_seg = mmcv.imrescale(
                    results[key], results['scale'], interpolation='nearest')
            else:
                gt_seg = mmcv.imresize(
                    results[key], results['scale'], interpolation='nearest')
            results[key] = gt_seg  # update the annotation

    def __call__(self, results):
        """Pick a scale if none is set, then resize image and annotations."""
        if 'scale' not in results:  # taken in this walkthrough
            self._random_scale(results)
        self._resize_img(results)  # resize the image
        self._resize_seg(results)  # resize the annotations
        return results
RandomCrop类位于mmseg/datasets/pipelines/transforms.py,如下
class RandomCrop(object):
    """Randomly crop the image and its segmentation fields.

    Args:
        crop_size: ``(h, w)`` of the crop; both must be positive.
        cat_max_ratio: upper bound on the share of the crop a single label
            may occupy; values below 1 trigger up to 10 re-samples.
        ignore_index: label excluded when computing that share.
    """

    def __init__(self, crop_size, cat_max_ratio=1., ignore_index=255):
        assert crop_size[0] > 0 and crop_size[1] > 0
        self.crop_size = crop_size
        self.cat_max_ratio = cat_max_ratio
        self.ignore_index = ignore_index

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box ``(y1, y2, x1, x2)``."""
        margin_h = max(img.shape[0] - self.crop_size[0], 0)
        margin_w = max(img.shape[1] - self.crop_size[1], 0)
        # Randomly pick the top-left starting point of the crop.
        offset_h = np.random.randint(0, margin_h + 1)
        offset_w = np.random.randint(0, margin_w + 1)
        # Derive the crop region from the offsets.
        crop_y1, crop_y2 = offset_h, offset_h + self.crop_size[0]
        crop_x1, crop_x2 = offset_w, offset_w + self.crop_size[1]

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
        return img

    def __call__(self, results):
        """Crop ``results['img']`` and every field in ``seg_fields`` alike."""
        img = results['img']
        crop_bbox = self.get_crop_bbox(img)  # get the crop region
        if self.cat_max_ratio < 1.:  # taken in this walkthrough
            # Repeat 10 times
            for _ in range(10):
                # Crop the annotation to inspect its label distribution.
                seg_temp = self.crop(results['gt_semantic_seg'], crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                # len(cnt) > 1: the crop holds at least two labels other
                # than ignore_index.
                # np.max(cnt) / np.sum(cnt) < cat_max_ratio: no single label
                # covers cat_max_ratio or more of the non-ignored pixels.
                # Together they keep the crop as varied in labels as possible.
                if len(cnt) > 1 and np.max(cnt) / np.sum(
                        cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)  # sample a new region

        # crop the image
        img = self.crop(img, crop_bbox)
        img_shape = img.shape
        results['img'] = img  # update the image
        results['img_shape'] = img_shape  # size of the cropped image

        # crop semantic seg
        for key in results.get('seg_fields', []):
            results[key] = self.crop(results[key], crop_bbox)

        return results


# The RandomFlip class is located in mmseg/datasets/pipelines/transforms.py, as follows
class RandomFlip(object):
    """Randomly flip the image and its segmentation fields.

    Args:
        prob: flipping probability in ``[0, 1]``; the decorator maps the
            deprecated ``flip_ratio`` keyword onto it.
        direction: ``'horizontal'`` or ``'vertical'``.
    """

    @deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
    def __init__(self, prob=None, direction='horizontal'):
        self.prob = prob
        self.direction = direction
        if prob is not None:
            assert prob >= 0 and prob <= 1
        assert direction in ['horizontal', 'vertical']

    def __call__(self, results):
        """Decide whether to flip (unless already decided) and apply it."""
        if 'flip' not in results:  # taken in this walkthrough
            flip = True if np.random.rand() < self.prob else False
            results['flip'] = flip  # record whether to flip
        if 'flip_direction' not in results:
            # Record the flip direction; the default is horizontal.
            results['flip_direction'] = self.direction
        if results['flip']:
            # flip image
            results['img'] = mmcv.imflip(
                results['img'], direction=results['flip_direction'])

            # flip segs
            for key in results.get('seg_fields', []):
                # use copy() to make numpy stride positive
                results[key] = mmcv.imflip(
                    results[key], direction=results['flip_direction']).copy()
        return results
PhotoMetricDistortion类位于mmseg/datasets/pipelines/transforms.py,如下
class PhotoMetricDistortion(object):
    """Apply random brightness, contrast, saturation and hue changes.

    NOTE(review): ``random.randint(2)`` / ``random.uniform`` are assumed to
    come from ``numpy.random`` (``from numpy import random``); the import is
    not visible in this excerpt.

    Args:
        brightness_delta: maximum absolute brightness offset.
        contrast_range: ``(lower, upper)`` range of the contrast factor.
        saturation_range: ``(lower, upper)`` range of the saturation factor.
        hue_delta: maximum absolute hue offset.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply by ``alpha``, add ``beta``, clip to [0, 255], cast to uint8."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        if random.randint(2):  # yields 0 or 1 (see class note)
            # Adjust the brightness.
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(2):
            # Adjust the contrast.
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower, self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            # Adjust the saturation channel.
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(2):
            img = mmcv.bgr2hsv(img)
            # Adjust the hue channel, wrapping around at 180.
            img[:, :,
                0] = (img[:, :, 0].astype(int) +
                      random.randint(-self.hue_delta, self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, results):
        """Distort ``results['img']`` and return the updated dict."""
        img = results['img']  # get the image
        # random brightness
        img = self.brightness(img)

        # mode == 1 --> random contrast is applied before saturation and hue
        # mode == 0 --> random contrast is applied after saturation and hue
        mode = random.randint(2)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        results['img'] = img
        return results


# The Normalize class is located in mmseg/datasets/pipelines/transforms.py, as follows
class Normalize(object):
    """Normalize the image with the given mean and std.

    Args:
        mean: per-channel mean values, stored as a float32 array.
        std: per-channel std values, stored as a float32 array.
        to_rgb: forwarded to ``mmcv.imnormalize``.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)
        self.to_rgb = to_rgb

    def __call__(self, results):
        """Normalize ``results['img']`` and record the parameters used."""
        # Normalize the image.
        results['img'] = mmcv.imnormalize(results['img'], self.mean, self.std,
                                          self.to_rgb)
        # Update the img_norm_cfg entry.
        results['img_norm_cfg'] = dict(
            mean=self.mean, std=self.std, to_rgb=self.to_rgb)
        return results


# The Pad class is located in mmseg/datasets/pipelines/transforms.py, as follows
class Pad(object):
    """Pad the image and its segmentation fields.

    Exactly one of ``size`` and ``size_divisor`` must be given.

    Args:
        size: fixed target shape passed to ``mmcv.impad``.
        size_divisor: divisor passed to ``mmcv.impad_to_multiple``.
        pad_val: padding value for the image.
        seg_pad_val: padding value for the segmentation maps.
    """

    def __init__(self,
                 size=None,
                 size_divisor=None,
                 pad_val=0,
                 seg_pad_val=255):
        self.size = size
        self.size_divisor = size_divisor
        self.pad_val = pad_val
        self.seg_pad_val = seg_pad_val
        # only one of size and size_divisor should be valid
        assert size is not None or size_divisor is not None
        assert size is None or size_divisor is None

    def _pad_img(self, results):
        """Pad images according to ``self.size``."""
        if self.size is not None:  # taken in this walkthrough
            # Pad the image with pad_val (0 here) up to the crop size.
            padded_img = mmcv.impad(
                results['img'], shape=self.size, pad_val=self.pad_val)
        elif self.size_divisor is not None:
            padded_img = mmcv.impad_to_multiple(
                results['img'], self.size_divisor, pad_val=self.pad_val)
        results['img'] = padded_img  # update the image
        results['pad_shape'] = padded_img.shape  # shape of the padded image
        results['pad_fixed_size'] = self.size  # the fixed size padded to
        results['pad_size_divisor'] = self.size_divisor  # None in this walkthrough

    def _pad_seg(self, results):
        """Pad masks according to ``results['pad_shape']``."""
        for key in results.get('seg_fields', []):
            # Pad the annotation with seg_pad_val (255 here) up to the crop
            # size.
            results[key] = mmcv.impad(
                results[key],
                shape=results['pad_shape'][:2],
                pad_val=self.seg_pad_val)

    def __call__(self, results):
        """Pad the image, then the annotations, and return ``results``."""
        self._pad_img(results)  # pad the image
        self._pad_seg(results)  # pad the annotations
        return results


# The DefaultFormatBundle class is located in mmseg/datasets/pipelines/formatting.py, as follows
class DefaultFormatBundle(object):
    """Convert ``img`` and ``gt_semantic_seg`` to tensors in DataContainers.

    Keys that are absent from ``results`` are left untouched.
    """

    def __call__(self, results):
        """Format the known fields of ``results`` and return it."""
        if 'img' in results:  # taken in this walkthrough
            img = results['img']  # get the image
            if len(img.shape) < 3:
                img = np.expand_dims(img, -1)
            # [H, W, C] --> [C, H, W]
            img = np.ascontiguousarray(img.transpose(2, 0, 1))
            # Convert to a tensor and wrap it in a DataContainer.
            results['img'] = DC(to_tensor(img), stack=True)
        if 'gt_semantic_seg' in results:
            # convert to long; a leading channel dimension is added before
            # converting to a tensor and wrapping it in a DataContainer.
            results['gt_semantic_seg'] = DC(
                to_tensor(results['gt_semantic_seg'][None,
                                                     ...].astype(np.int64)),
                stack=True)
        return results
Collect类位于mmseg/datasets/pipelines/formatting.py,如下
class Collect(object):
    """Collect the task-relevant entries and the image meta information.

    Args:
        keys: keys of ``results`` copied into the output as they are.
        meta_keys: keys of ``results`` gathered into ``img_metas``.
    """

    def __init__(self,
                 keys,
                 meta_keys=('filename', 'ori_filename', 'ori_shape',
                            'img_shape', 'pad_shape', 'scale_factor', 'flip',
                            'flip_direction', 'img_norm_cfg')):
        self.keys = keys
        self.meta_keys = meta_keys

    def __call__(self, results):
        """Return a new dict with ``img_metas`` plus every key in ``keys``."""
        data = {}
        img_meta = {}  # holds basic information about the image
        for key in self.meta_keys:
            img_meta[key] = results[key]
        # The meta information is wrapped in a DataContainer.
        data['img_metas'] = DC(img_meta, cpu_only=True)
        for key in self.keys:
            data[key] = results[key]  # fetch the task-relevant entries
        return data


# After the processing above, __getitem__ finally returns a dict (Dict[str, DataContainer]) with 3 entries: img_metas: Dict[str, *], img: Tensor [3, H, W] and gt_semantic_seg: Tensor [1, H, W] |
|