Refactor python module structure

This commit is contained in:
Piv
2020-03-05 21:43:22 +10:30
parent 043c8783a4
commit 82a214c209
24 changed files with 20 additions and 120 deletions

0
tracking/__init__.py Normal file
View File

127
tracking/algorithms.py Normal file
View File

@@ -0,0 +1,127 @@
import math
class Group:
def __init__(self, number, points=[]):
self._points = points
self._number = number
self._minX = None
self._maxX = None
self._minY = None
self._maxY = None
def add_point(self, point):
self._points.append(point)
self._update_min_max(point)
def get_points(self):
return self._points
@property
def number(self):
return self._number
@number.setter
def number(self, number):
self._number = number
def _update_min_max(self, new_point):
'''
Updates the in and max points for this group.
This is to determine when assigning groups whether the
same group is selected.
'''
converted_point = convert_lidar_to_cartesian(new_point)
if self._minX is None or self._minX > converted_point[0]:
self._minX = converted_point[0]
if self._maxX is None or self._maxX < converted_point[0]:
self._maxX = converted_point[0]
if self._minY is None or self._minY > converted_point[1]:
self._minY = converted_point[1]
if self._maxY is None or self._maxY < converted_point[1]:
self._maxY = converted_point[1]
def get_minX(self):
return self._minY
def get_maxX(self):
return self._maxY
def get_minY(self):
return self._minY
def get_maxY(self):
return self._maxY
def convert_lidar_to_cartesian(new_point):
x = new_point[2] * math.sin(new_point[1])
y = new_point[2] * math.cos(new_point[1])
return (x, y)
def calc_groups(scan):
'''
Calculates groups of points from a lidar scan. The scan should
already be sorted.
Should return all groups.
'''
prevPoint = None
currentGroup = None
allGroups = []
currentGroupNumber = 0
# assume the list is already sorted.
for point in scan:
if prevPoint is None:
prevPoint = point
continue
# Distances are in mm.
# within 1cm makes a group. Will need to play around with this.
if (point[2] - prevPoint[2]) ** 2 < 10 ** 2:
if currentGroup is None:
currentGroup = Group(currentGroupNumber)
allGroups.append(currentGroup)
currentGroup.add_point(point)
else:
if currentGroup is not None:
currentGroupNumber += 1
currentGroup = None
prevPoint = point
return allGroups
def find_centre(group):
return ((group.get_maxX() + group.get_minX()) / 2, (group.get_maxY() + group.get_minY()) / 2)
def assign_groups(prev_groups, new_groups):
'''
Assigns group numbers to a new scan based on the groups of an old scan.
'''
for group in prev_groups:
old_centre = find_centre(prev_groups)
for new_group in new_groups:
new_centre = find_centre(new_group)
# They are considered the same if the new group and old group centres are within 5cm.
if ((new_centre[0] - old_centre[0]) ** 2 + (new_centre[1] - old_centre[1]) ** 2) < 50 ** 2:
new_group.set_number(group.get_number())
return new_groups
def updateCarVelocity(oldGroup, newGroup):
'''
Return a tuple (throttleChange, steeringChange) that should be
applied given the change in the centre of the groups.
'''
pass

61
tracking/lidar_cache.py Normal file
View File

@@ -0,0 +1,61 @@
import rplidar
from rplidar import RPLidar
from threading import Thread
from tracking import algorithms
import tracking.lidar_tracker_pb2 as tracker_pb
import zmq
import Messaging.message_factory as mf
import Messaging.messages as messages
class LidarCache():
'''
A class that retrieves scans from the lidar,
runs grouping algorithms between scans and
keeps a copy of the group data.
'''
run = True
tracking_group_number = -1
currentGroups = None
groupsChanged = []
port = None
def __init__(self, measurements=100):
self.lidar = RPLidar('/dev/ttyUSB0')
self.measurements = measurements
print('Info: ' + self.lidar.get_info())
print('Health: ' + self.lidar.get_health())
def start_cache(self):
if self.port is None:
print('ERROR: Port has not been set!')
return
self.thread = Thread(target=self.do_scanning)
self.thread.start()
def do_scanning(self):
'''
Performs a scan for the given number of iterations.
'''
# Create the 0MQ socket first. This should not be passed between threads.
self._mFactory = mf.getZmqPubSubStreamer(self.port)
for i, scan in enumerate(self.lidar.iter_scans(min_len=self.measurements)):
print('%d: Got %d measurments' % (i, len(scan)))
if(not self.run):
break
# Now process the groups.
if self.currentGroups is not None:
self.currentGroups = algorithms.assign_groups(
self.currentGroups, algorithms.calc_groups(scan))
else:
self.currentGroups = algorithms.calc_groups(scan)
def fireGroupsChanged(self):
# Send the updated groups to 0MQ socket.
self._mFactory.send_message_topic("lidar_map", messages.ProtoMessage(
message=tracker_pb.PointScan(points=[]).SerializeToString()))
def stop_scanning(self):
self.run = False

View File

@@ -0,0 +1,22 @@
import tracking.lidar_tracker_pb2 as lidar_tracker_pb2
from tracking.lidar_tracker_pb2_grpc import PersonTrackingServicer
from tracking.lidar_cache import LidarCache
from multiprocessing import Process
class LidarServicer(PersonTrackingServicer):
def __init__(self):
self.cache = LidarCache(measurements=100)
def set_tracking_group(self, request, context):
pass
def stop_tracking(self, request, context):
self.cache.stop_scanning()
def start_tracking(self, request, context):
'''
Starts the lidar cache.
'''
self.cache.start_cache()

View File

@@ -0,0 +1,237 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: persontracking/lidar_tracker.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='persontracking/lidar_tracker.proto',
package='persontracking',
syntax='proto3',
serialized_options=_b('\n\031com.example.carcontrollerB\021MotorServiceProtoP\001'),
serialized_pb=_b('\n\"persontracking/lidar_tracker.proto\x12\x0epersontracking\"\x1b\n\nInt32Value\x12\r\n\x05value\x18\x01 \x01(\x05\"\x07\n\x05\x45mpty\">\n\x05Point\x12\r\n\x05\x61ngle\x18\x01 \x01(\x01\x12\x10\n\x08\x64istance\x18\x02 \x01(\x05\x12\x14\n\x0cgroup_number\x18\x03 \x01(\x05\"2\n\tPointScan\x12%\n\x06points\x18\x01 \x03(\x0b\x32\x15.persontracking.Point2\xe3\x01\n\x0ePersonTracking\x12I\n\x12set_tracking_group\x12\x1a.persontracking.Int32Value\x1a\x15.persontracking.Empty\"\x00\x12?\n\rstop_tracking\x12\x15.persontracking.Empty\x1a\x15.persontracking.Empty\"\x00\x12\x45\n\x0estart_tracking\x12\x1a.persontracking.Int32Value\x1a\x15.persontracking.Empty\"\x00\x42\x30\n\x19\x63om.example.carcontrollerB\x11MotorServiceProtoP\x01\x62\x06proto3')
)
_INT32VALUE = _descriptor.Descriptor(
name='Int32Value',
full_name='persontracking.Int32Value',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='value', full_name='persontracking.Int32Value.value', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=54,
serialized_end=81,
)
_EMPTY = _descriptor.Descriptor(
name='Empty',
full_name='persontracking.Empty',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=83,
serialized_end=90,
)
_POINT = _descriptor.Descriptor(
name='Point',
full_name='persontracking.Point',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='angle', full_name='persontracking.Point.angle', index=0,
number=1, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='distance', full_name='persontracking.Point.distance', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
_descriptor.FieldDescriptor(
name='group_number', full_name='persontracking.Point.group_number', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=92,
serialized_end=154,
)
_POINTSCAN = _descriptor.Descriptor(
name='PointScan',
full_name='persontracking.PointScan',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='points', full_name='persontracking.PointScan.points', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=156,
serialized_end=206,
)
_POINTSCAN.fields_by_name['points'].message_type = _POINT
DESCRIPTOR.message_types_by_name['Int32Value'] = _INT32VALUE
DESCRIPTOR.message_types_by_name['Empty'] = _EMPTY
DESCRIPTOR.message_types_by_name['Point'] = _POINT
DESCRIPTOR.message_types_by_name['PointScan'] = _POINTSCAN
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
Int32Value = _reflection.GeneratedProtocolMessageType('Int32Value', (_message.Message,), dict(
DESCRIPTOR = _INT32VALUE,
__module__ = 'persontracking.lidar_tracker_pb2'
# @@protoc_insertion_point(class_scope:persontracking.Int32Value)
))
_sym_db.RegisterMessage(Int32Value)
Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), dict(
DESCRIPTOR = _EMPTY,
__module__ = 'persontracking.lidar_tracker_pb2'
# @@protoc_insertion_point(class_scope:persontracking.Empty)
))
_sym_db.RegisterMessage(Empty)
Point = _reflection.GeneratedProtocolMessageType('Point', (_message.Message,), dict(
DESCRIPTOR = _POINT,
__module__ = 'persontracking.lidar_tracker_pb2'
# @@protoc_insertion_point(class_scope:persontracking.Point)
))
_sym_db.RegisterMessage(Point)
PointScan = _reflection.GeneratedProtocolMessageType('PointScan', (_message.Message,), dict(
DESCRIPTOR = _POINTSCAN,
__module__ = 'persontracking.lidar_tracker_pb2'
# @@protoc_insertion_point(class_scope:persontracking.PointScan)
))
_sym_db.RegisterMessage(PointScan)
DESCRIPTOR._options = None
_PERSONTRACKING = _descriptor.ServiceDescriptor(
name='PersonTracking',
full_name='persontracking.PersonTracking',
file=DESCRIPTOR,
index=0,
serialized_options=None,
serialized_start=209,
serialized_end=436,
methods=[
_descriptor.MethodDescriptor(
name='set_tracking_group',
full_name='persontracking.PersonTracking.set_tracking_group',
index=0,
containing_service=None,
input_type=_INT32VALUE,
output_type=_EMPTY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='stop_tracking',
full_name='persontracking.PersonTracking.stop_tracking',
index=1,
containing_service=None,
input_type=_EMPTY,
output_type=_EMPTY,
serialized_options=None,
),
_descriptor.MethodDescriptor(
name='start_tracking',
full_name='persontracking.PersonTracking.start_tracking',
index=2,
containing_service=None,
input_type=_INT32VALUE,
output_type=_EMPTY,
serialized_options=None,
),
])
_sym_db.RegisterServiceDescriptor(_PERSONTRACKING)
DESCRIPTOR.services_by_name['PersonTracking'] = _PERSONTRACKING
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,80 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
import grpc
from tracking import lidar_tracker_pb2 as persontracking_dot_lidar__tracker__pb2
class PersonTrackingStub(object):
# missing associated documentation comment in .proto file
pass
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.set_tracking_group = channel.unary_unary(
'/persontracking.PersonTracking/set_tracking_group',
request_serializer=persontracking_dot_lidar__tracker__pb2.Int32Value.SerializeToString,
response_deserializer=persontracking_dot_lidar__tracker__pb2.Empty.FromString,
)
self.stop_tracking = channel.unary_unary(
'/persontracking.PersonTracking/stop_tracking',
request_serializer=persontracking_dot_lidar__tracker__pb2.Empty.SerializeToString,
response_deserializer=persontracking_dot_lidar__tracker__pb2.Empty.FromString,
)
self.start_tracking = channel.unary_unary(
'/persontracking.PersonTracking/start_tracking',
request_serializer=persontracking_dot_lidar__tracker__pb2.Int32Value.SerializeToString,
response_deserializer=persontracking_dot_lidar__tracker__pb2.Empty.FromString,
)
class PersonTrackingServicer(object):
# missing associated documentation comment in .proto file
pass
def set_tracking_group(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def stop_tracking(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def start_tracking(self, request, context):
# missing associated documentation comment in .proto file
pass
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_PersonTrackingServicer_to_server(servicer, server):
rpc_method_handlers = {
'set_tracking_group': grpc.unary_unary_rpc_method_handler(
servicer.set_tracking_group,
request_deserializer=persontracking_dot_lidar__tracker__pb2.Int32Value.FromString,
response_serializer=persontracking_dot_lidar__tracker__pb2.Empty.SerializeToString,
),
'stop_tracking': grpc.unary_unary_rpc_method_handler(
servicer.stop_tracking,
request_deserializer=persontracking_dot_lidar__tracker__pb2.Empty.FromString,
response_serializer=persontracking_dot_lidar__tracker__pb2.Empty.SerializeToString,
),
'start_tracking': grpc.unary_unary_rpc_method_handler(
servicer.start_tracking,
request_deserializer=persontracking_dot_lidar__tracker__pb2.Int32Value.FromString,
response_serializer=persontracking_dot_lidar__tracker__pb2.Empty.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'persontracking.PersonTracking', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))