Merge branch 'kitti_depth_dataset' into 'main'
Kitti depth dataset

See merge request vato007/fast-depth-tf!2
@@ -87,13 +87,13 @@ def dense_nnconv5(size, weights=None, shape=(224, 224, 3), half_features=True):
    return keras.Model(inputs=input, outputs=decoder, name="fast_dense_depth")


-def load_nyu():
+def load_nyu(download_dir='../nyu'):
     """
     Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
     :return: nyu_v2 dataset builder
     """
     builder = tfds.builder('nyu_depth_v2')
-    builder.download_and_prepare(download_dir='../nyu')
+    builder.download_and_prepare(download_dir=download_dir)
     return builder \
         .as_dataset(split='train', shuffle_files=True) \
         .shuffle(buffer_size=1024) \
@@ -101,13 +101,13 @@ def load_nyu():
         .map(lambda x: fd.crop_and_resize(x))


-def load_nyu_evaluate():
+def load_nyu_evaluate(download_dir='../nyu'):
     """
     Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
     :return: nyu_v2 dataset builder
     """
     builder = tfds.builder('nyu_depth_v2')
-    builder.download_and_prepare(download_dir='../nyu')
+    builder.download_and_prepare(download_dir=download_dir)
     return builder.as_dataset(split='validation').batch(1).map(lambda x: fd.crop_and_resize(x))
@@ -1,6 +1,7 @@
 import tensorflow as tf
 import tensorflow.keras as keras
 import tensorflow_datasets as tfds
+# Needed for the kitti dataset, don't delete

 """
 Unofficial tensorflow keras implementation of FastDepth (mobilenet_nnconv5).
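The "Needed for the kitti dataset, don't delete" note reflects how tfds discovers custom builders: the KittiDepth class defined in kitti_depth/kitti_depth.py (added below) is only registered under the name 'kitti_depth' once its module has been imported somewhere. A minimal sketch of that registration step; the exact module path is assumed from the new file's location, not confirmed by the diff:

import tensorflow_datasets as tfds
import kitti_depth.kitti_depth  # noqa: F401 -- importing the module registers the KittiDepth builder

builder = tfds.builder('kitti_depth')  # resolves only after the import above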
@@ -162,11 +163,15 @@ def load_model(file):
 def crop_and_resize(x):
     shape = tf.shape(x['depth'])
+    img_shape = tf.shape(x['image'])
+    # Ensure we get a square for when we resize later.
+    # For horizontal images this is basically just cropping the sides off
+    center_shape = min(shape[1], shape[2], img_shape[1], img_shape[2])

     def layer():
         return keras.Sequential([
             keras.layers.experimental.preprocessing.CenterCrop(
-                shape[1], shape[2]),
+                center_shape, center_shape),
             keras.layers.experimental.preprocessing.Resizing(
                 224, 224, interpolation='nearest')
         ])
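The change above replaces the old crop (which used the depth map's own size) with the largest centred square that fits both the image and the depth map, followed by a nearest-neighbour resize to 224x224, so wide driving frames lose their sides rather than being distorted. A minimal sketch of the effect; the 375x1242 frame size and the zero-filled batch are placeholders, and the TF 2.x experimental preprocessing layers from the diff are assumed to be available:

import tensorflow as tf
import tensorflow.keras as keras

images = tf.zeros([2, 375, 1242, 3])   # placeholder batch of wide frames
center = min(375, 1242)                # largest centred square: 375x375

square_then_resize = keras.Sequential([
    keras.layers.experimental.preprocessing.CenterCrop(center, center),
    keras.layers.experimental.preprocessing.Resizing(224, 224, interpolation='nearest'),
])
print(square_then_resize(images).shape)  # (2, 224, 224, 3)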
@@ -175,13 +180,13 @@ def crop_and_resize(x):
     return layer()(x['image']), layer()(tf.reshape(x['depth'], [shape[0], shape[1], shape[2], 1]))


-def load_nyu():
+def load_nyu(download_dir='../nyu'):
     """
     Load the nyu_v2 dataset train split. Will be downloaded to ../nyu
     :return: nyu_v2 dataset builder
     """
     builder = tfds.builder('nyu_depth_v2')
-    builder.download_and_prepare(download_dir='../nyu')
+    builder.download_and_prepare(download_dir=download_dir)
     return builder \
         .as_dataset(split='train', shuffle_files=True) \
         .shuffle(buffer_size=1024) \
@@ -189,16 +194,22 @@ def load_nyu():
         .map(lambda x: crop_and_resize(x))


-def load_nyu_evaluate():
+def load_nyu_evaluate(download_dir='../nyu'):
     """
     Load the nyu_v2 dataset validation split. Will be downloaded to ../nyu
     :return: nyu_v2 dataset builder
     """
     builder = tfds.builder('nyu_depth_v2')
-    builder.download_and_prepare(download_dir='../nyu')
+    builder.download_and_prepare(download_dir=download_dir)
     return builder.as_dataset(split='validation').batch(1).map(lambda x: crop_and_resize(x))


+def load_kitti(download_dir='../kitti'):
+    ds = tfds.builder('kitti_depth')
+    ds.download_and_prepare(download_dir=download_dir)
+    return ds.as_dataset(tfds.Split.TRAIN).batch(8).map(lambda x: crop_and_resize(x))


 if __name__ == '__main__':
     model = mobilenet_nnconv5()
     model.summary()
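With the download directory now parameterised and load_kitti added, the loaders can be pointed at any local cache. A minimal usage sketch; the module alias, the paths, and the printed shapes are assumptions rather than part of the repository:

import fast_depth as fd  # hypothetical module name for the file shown above

train_ds = fd.load_nyu(download_dir='/data/nyu')        # NYU v2 train split, 224x224 crops
val_ds = fd.load_nyu_evaluate(download_dir='/data/nyu')
kitti_ds = fd.load_kitti(download_dir='/data/kitti')    # KITTI depth, batches of 8

for image, depth in kitti_ds.take(1):
    # crop_and_resize yields (image, depth) pairs; depth gains a trailing channel axis
    print(image.shape, depth.shape)  # expected roughly (8, 224, 224, 3) and (8, 224, 224, 1)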
kitti_depth/kitti_depth.py (new file, 178 lines)
@@ -0,0 +1,178 @@
# coding=utf-8
# Copyright 2021 The TensorFlow Datasets Authors, Michael Pivato
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Kitti dataset."""

import os
import zipfile

import numpy as np
import tensorflow as tf
import tensorflow_datasets.public_api as tfds

_CITATION = """\
@INPROCEEDINGS{Uhrig2017THREEDV,
  author = {Jonas Uhrig and Nick Schneider and Lukas Schneider and Uwe Franke and Thomas Brox and Andreas Geiger},
  title = {Sparsity Invariant CNNs},
  booktitle = {International Conference on 3D Vision (3DV)},
  year = {2017}
}
"""
_DESCRIPTION = """\
Kitti contains a suite of vision tasks built using an autonomous driving
platform. The full benchmark contains many tasks such as stereo, optical flow,
visual odometry, etc. This dataset contains the object detection dataset,
including the monocular images and bounding boxes. The dataset contains 7481
training images annotated with 3D bounding boxes. A full description of the
annotations can be found in the readme of the object development kit readme on
the Kitti homepage.
"""
_HOMEPAGE_URL = "http://www.cvlibs.net/datasets/kitti/"
_DATA_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti"
_RAW_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data"
_DRIVE_URL = "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{}/{}.zip"
_DEPTH_FNAME = "data_depth_annotated.zip"

class KittiDepth(tfds.core.GeneratorBasedBuilder):
  """Kitti dataset."""

  VERSION = tfds.core.Version("3.2.0")
  SUPPORTED_VERSIONS = [
      tfds.core.Version("3.1.0"),
  ]
  RELEASE_NOTES = {
      "3.2.0": "Initial Implementation."
  }

  def _info(self):
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict({
            "image": tfds.features.Image(shape=(None, None, 3)),
            "depth": tfds.features.Tensor(shape=(None, None), dtype=tf.float64),
        }),
        supervised_keys=("image", "depth"),
        homepage=_HOMEPAGE_URL,
        citation=_CITATION,
    )

  def _split_generators(self, dl_manager: tfds.download.DownloadManager):
    depth_map_name = "depth_map"
    filenames = {
        depth_map_name: _DATA_URL + "/" + _DEPTH_FNAME,
    }
    files = dl_manager.download(filenames)

    img_files = {}

    # for fpath, fobj in dl_manager.iter_archive(files[depth_map_name]):
    for fpath, fobj in iter_zip(files[depth_map_name]):
      # Save all drives into a dict, then batch download them.
      drive = fpath.split(os.path.sep)[1]
      if img_files.get(drive) is None:
        # Strip off the sync so that we can get the correct url, but still use the synced files
        img_files[drive] = _DRIVE_URL.format(drive[0:-5], drive)

    dl_img_files = dl_manager.download(img_files)

    # TODO: Download test selections (no depths)

    return [
        tfds.core.SplitGenerator(
            name=tfds.Split.TRAIN,
            gen_kwargs={
                "depths": iter_zip(files[depth_map_name]),
                "depth_folder": "train",
                "img_paths": dl_img_files,
            }),
        tfds.core.SplitGenerator(
            name=tfds.Split.VALIDATION,
            gen_kwargs={
                "depths": iter_zip(files[depth_map_name]),
                "depth_folder": "val",
                "img_paths": dl_img_files,
            }),
        # TODO: Test? This would probably be the selection: data_depth_selection.zip
        # It has no annotations though
        # tfds.core.SplitGenerator(
        #     name=tfds.Split.TEST,
        #     gen_kwargs={
        #     }),
    ]
  def _generate_examples(self, depths, depth_folder, img_paths):
    """Yields images and annotations.

    Args:
      depths: iterator over folder of downloaded depth maps

    Yields:
      A tuple containing the example's key, and the example.
    """
    Image = tfds.core.lazy_imports.PIL_Image

    for fpath, fobj in depths:
      fpath_split = fpath.split(os.path.sep)
      if fpath_split[0] != depth_folder:
        continue
      # Convert to numpy, apply transforms, yield the image.
      # Same as the kitti devkit
      depth_png = np.array(Image.open(fobj), dtype=int)
      depth = depth_png.astype(np.float) / 256.
      depth[depth_png == 0] = -1.

      zip_iterator = iter_zip(img_paths[fpath.split(os.path.sep)[1]])
      img = None

      # TODO: Make this faster, currently takes 2.7s per example
      # Tried keeping the iterator open to save time but this rarely ended up actually working
      for img_path, img_obj in zip_iterator:
        img_path_split = img_path.split(os.path.sep)
        # We want image_2 for the right side
        if img_path_split[2] == fpath_split[-2] and img_path_split[-1] == fpath_split[-1]:
          img = img_obj
          break
      yield "{}_{}_{}".format(fpath_split[1], fpath_split[4], fpath_split[5]), {
          "image": np.array(Image.open(img)), "depth": depth}
#### WARNING: THESE ARE HERE IN PLACE OF THE OFFICIAL DownloadManager#iter_archive implementation as it's broken on windows zip files #####
#### See here, and if it's resolved this can be removed: https://github.com/tensorflow/tensorflow/issues/35630#issuecomment-600811558 #####
# Taken directly from extractor.py
def iter_zip(arch_f):
  """Iterate over zip archive."""
  # This is modified to just open and iterate through the zip file
  with open(arch_f, 'rb') as fobj:
    z = zipfile.ZipFile(fobj)
    for member in z.infolist():
      extract_file = z.open(member)
      if member.is_dir():  # Filter directories  # pytype: disable=attribute-error
        continue
      path = _normpath(member.filename)
      if not path:
        continue
      yield (path, extract_file)


def _normpath(path):
  path = os.path.normpath(path)
  if (path.startswith('.')
      or os.path.isabs(path)
      or path.endswith('~')
      or os.path.basename(path).startswith('.')):
    return None
  return path
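Because iter_zip stands in for DownloadManager.iter_archive, it can also be exercised directly against a local archive to sanity-check the member paths it yields. A small sketch; the archive path is a placeholder, not a path the repository guarantees:

archive = '../kitti/data_depth_annotated.zip'  # placeholder: any local zip works
png_count = 0
for path, fobj in iter_zip(archive):
    # Each member is yielded as (normalised path, open file object)
    if path.endswith('.png'):
        png_count += 1
    fobj.close()
print(png_count)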