Warning: Using this dataset requires more than 175 GB of disk space (TensorFlow Datasets will also generate examples, which take up additional space).
179 lines
6.7 KiB
Python
179 lines
6.7 KiB
Python
# coding=utf-8
|
|
# Copyright 2021 The TensorFlow Datasets Authors, Michael Pivato
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Kitti dataset."""
|
|
|
|
import os
|
|
import zipfile
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
import tensorflow_datasets.public_api as tfds
|
|
|
|
# BibTeX entry for the paper that introduced the annotated depth maps.
_CITATION = """\
@INPROCEEDINGS{Uhrig2017THREEDV,
  author = {Jonas Uhrig and Nick Schneider and Lukas Schneider and Uwe Franke and Thomas Brox and Andreas Geiger},
  title = {Sparsity Invariant CNNs},
  booktitle = {International Conference on 3D Vision (3DV)},
  year = {2017}
}
"""

# Human-readable summary shown in the dataset info. It describes what this
# builder actually yields (image/depth pairs), not the object-detection data.
_DESCRIPTION = """\
Kitti contains a suite of vision tasks built using an autonomous driving
platform. The full benchmark contains many tasks such as stereo, optical flow,
visual odometry, etc. This dataset contains the annotated depth maps from
`data_depth_annotated.zip`, each paired with the matching camera image from the
raw Kitti drive recordings. Depth values are the stored PNG values divided by
256, with -1 marking pixels that have no annotation. A full description of the
annotations can be found in the readme of the development kit on the Kitti
homepage.
"""

_HOMEPAGE_URL = "http://www.cvlibs.net/datasets/kitti/"

# Base location of the Kitti downloads; the other URLs are derived from it so
# the host only has to be changed in one place.
_DATA_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti"
_RAW_URL = _DATA_URL + "/raw_data"
# Template for one raw drive archive: format(<url directory>, <archive name>).
_DRIVE_URL = _RAW_URL + "/{}/{}.zip"
# Archive (relative to _DATA_URL) holding the annotated depth maps.
_DEPTH_FNAME = "data_depth_annotated.zip"
|
|
|
|
|
|
class KittiDepth(tfds.core.GeneratorBasedBuilder):
    """Kitti depth dataset: camera images paired with annotated depth maps."""

    VERSION = tfds.core.Version("3.2.0")
    SUPPORTED_VERSIONS = [
        tfds.core.Version("3.1.0"),
    ]
    RELEASE_NOTES = {
        "3.2.0": "Initial Implementation."
    }

    def _info(self):
        """Returns the dataset metadata (features, supervised keys, citation)."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                "image": tfds.features.Image(shape=(None, None, 3)),
                "depth": tfds.features.Tensor(shape=(None, None), dtype=tf.float64),
            }),
            supervised_keys=("image", "depth"),
            homepage=_HOMEPAGE_URL,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager: tfds.download.DownloadManager):
        """Downloads the depth archive and every raw drive archive it refers to.

        Args:
          dl_manager: download manager used to fetch the archives.

        Returns:
          The TRAIN and VALIDATION split generators. Both read the same depth
          archive and are told which top-level folder ("train" / "val") to keep.
        """
        depth_map_name = "depth_map"
        filenames = {
            depth_map_name: _DATA_URL + "/" + _DEPTH_FNAME,
        }
        files = dl_manager.download(filenames)

        # Collect every drive mentioned in the depth archive into a dict, then
        # batch download them. iter_zip is used instead of
        # dl_manager.iter_archive (see the warning above iter_zip).
        img_files = {}
        for fpath, fobj in iter_zip(files[depth_map_name]):
            # Only the member path is needed here; release the handle at once.
            fobj.close()
            drive = fpath.split(os.path.sep)[1]
            if drive not in img_files:
                # Strip off the last five characters (presumably the "_sync"
                # suffix) to get the correct url directory, but still download
                # the synced archive.
                img_files[drive] = _DRIVE_URL.format(drive[0:-5], drive)

        dl_img_files = dl_manager.download(img_files)

        # TODO: Download test selections (no depths). This would probably be
        # data_depth_selection.zip, which has no annotations, exposed as a
        # tfds.Split.TEST SplitGenerator.

        return [
            tfds.core.SplitGenerator(
                name=tfds.Split.TRAIN,
                gen_kwargs={
                    "depths": iter_zip(files[depth_map_name]),
                    "depth_folder": "train",
                    "img_paths": dl_img_files,
                }),
            tfds.core.SplitGenerator(
                name=tfds.Split.VALIDATION,
                gen_kwargs={
                    "depths": iter_zip(files[depth_map_name]),
                    "depth_folder": "val",
                    "img_paths": dl_img_files,
                }),
        ]

    def _generate_examples(self, depths, depth_folder, img_paths):
        """Yields images and annotations.

        Args:
          depths: iterator of (path, file object) pairs over the downloaded
            depth map archive.
          depth_folder: top-level archive folder to keep ("train" or "val");
            entries under any other folder are skipped.
          img_paths: dict mapping a drive name to the path of its downloaded
            raw image archive.

        Yields:
          A tuple containing the example's key, and the example.

        Raises:
          ValueError: if no image matching a depth map is found in the drive
            archive.
        """
        Image = tfds.core.lazy_imports.PIL_Image

        for fpath, fobj in depths:
            fpath_split = fpath.split(os.path.sep)
            if fpath_split[0] != depth_folder:
                fobj.close()
                continue

            # Convert to numpy and apply the same transform as the kitti
            # devkit: value / 256, with 0 (no annotation) mapped to -1.
            try:
                depth_png = np.array(Image.open(fobj), dtype=int)
            finally:
                fobj.close()
            # np.float was removed from NumPy; float64 matches the declared
            # "depth" feature dtype.
            depth = depth_png.astype(np.float64) / 256.
            depth[depth_png == 0] = -1.

            # TODO: Make this faster, currently takes 2.7s per example.
            # Tried keeping the iterator open to save time but this rarely
            # ended up actually working.
            zip_iterator = iter_zip(img_paths[fpath_split[1]])
            image = None
            try:
                for img_path, img_obj in zip_iterator:
                    img_path_split = img_path.split(os.path.sep)
                    # The image must come from the same camera folder and have
                    # the same file name as the depth map.
                    if (len(img_path_split) > 2
                            and img_path_split[2] == fpath_split[-2]
                            and img_path_split[-1] == fpath_split[-1]):
                        # Decode while the archive is still open.
                        image = np.array(Image.open(img_obj))
                        img_obj.close()
                        break
                    img_obj.close()
            finally:
                # Release the drive archive before moving to the next example.
                zip_iterator.close()

            if image is None:
                raise ValueError(
                    "No image matching depth map {!r} found in {!r}".format(
                        fpath, img_paths[fpath_split[1]]))

            key = "{}_{}_{}".format(
                fpath_split[1], fpath_split[4], fpath_split[5])
            yield key, {"image": image, "depth": depth}
|
|
|
|
|
|
#### WARNING: The helpers below are used in place of the official DownloadManager#iter_archive implementation, because it is broken for zip files on Windows #####
|
|
#### See here, and if it's resolved this can be removed: https://github.com/tensorflow/tensorflow/issues/35630#issuecomment-600811558 #####
|
|
# Taken directly from extractor.py
|
|
def iter_zip(arch_f):
    """Iterate over the files of a zip archive.

    Args:
      arch_f: path of the zip archive on disk.

    Yields:
      (path, file object) pairs, one per regular member whose normalised path
      is accepted by `_normpath`. The file object is only usable while the
      generator is still alive, and the caller is responsible for closing it.
    """
    # This is modified to just open and iterate through the zip file.
    with open(arch_f, 'rb') as fobj, zipfile.ZipFile(fobj) as z:
        for member in z.infolist():
            if member.is_dir():  # Filter directories  # pytype: disable=attribute-error
                continue
            path = _normpath(member.filename)
            if not path:
                continue
            # Open the member only once we know it will be yielded, so skipped
            # members never leave an unclosed handle behind.
            yield (path, z.open(member))
|
|
|
|
|
|
def _normpath(path):
    """Normalise an archive member path, or return None if it must be skipped.

    A path is skipped when, after `os.path.normpath`, it starts with a dot,
    is absolute, ends with `~`, or has a base name starting with a dot.
    """
    normalized = os.path.normpath(path)
    rejected = (
        normalized.startswith('.')
        or os.path.isabs(normalized)
        or normalized.endswith('~')
        or os.path.basename(normalized).startswith('.')
    )
    return None if rejected else normalized
|