Move car python code into src subdirectory
51
car/src/DecisionSystem/CentralisedDecision/ballotvoter.py
Normal file
@@ -0,0 +1,51 @@
from DecisionSystem.messages import ConnectSwarm, SubmitVote, deserialise, RequestVote, ClientVoteRequest, VoteResult
from DecisionSystem.CentralisedDecision.messenger import Messenger


class BallotVoter:
    def __init__(self, on_vote, handle_agreement, messenger: Messenger):
        self.messenger = messenger
        self.messenger.add_message_callback(self.on_message)
        self.messenger.add_connect(self.on_connect)
        self.on_vote = on_vote
        self.handle_agreement = handle_agreement

    def on_connect(self, rc):
        print("Connected with result code " + str(rc))

        # Tell the commander we are now connected.
        self.send_connect()

    def on_message(self, message):
        print("Message received!")
        messageD = deserialise(message.payload)
        print("Message type: " + messageD.type)
        if messageD.type == RequestVote().type:
            print('Received vote request')
            self.submit_vote()
        elif messageD.type == "listening":
            self.send_connect()
        elif messageD.type == VoteResult._type:
            # Compare against the class-level type tag; VoteResult.type on the
            # class is a property object, never equal to a string.
            self.handle_agreement(messageD.data["vote"])

    def submit_vote(self):
        v = self.on_vote()
        if v is None:
            print('Could not get vote')
            return
        print("Got vote")
        vote = SubmitVote(v, self.messenger.id)
        print('Created vote message')
        self.messenger.broadcast_message(self.messenger.swarm, vote.serialise())
        print('Published vote')

    def send_connect(self):
        # Send a connect message to let any commanders know that
        # this voter is available.
        self.messenger.broadcast_message(self.messenger.swarm, ConnectSwarm(self.messenger.id).serialise())

    def request_vote(self):
        """Sends a request to the leader to start collecting votes."""
        self.messenger.broadcast_message(self.messenger.swarm, ClientVoteRequest(self.messenger.id).serialise())
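A minimal wiring sketch for this class (editor's illustration, not part of the commit; the placeholder callbacks and the assumption that config.json holds the keys MqttMessenger reads are marked below):

import json
from DecisionSystem.CentralisedDecision.messenger import MqttMessenger
from DecisionSystem.CentralisedDecision.ballotvoter import BallotVoter

with open('config.json') as f:   # assumed: see messenger.py for the expected keys
    cfg = json.load(f)

def on_vote():
    return 3                     # placeholder: a real voter returns a recognised gesture

def handle_agreement(vote):
    print("Swarm agreed on:", vote)

m = MqttMessenger(cfg)
voter = BallotVoter(on_vote, handle_agreement, m)
m.connect()                      # on_connect fires and the voter announces itself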
95
car/src/DecisionSystem/CentralisedDecision/cameraserver.py
Normal file
@@ -0,0 +1,95 @@
from DecisionSystem.CentralisedDecision.ballotvoter import BallotVoter
from DecisionSystem.CentralisedDecision.messenger import MqttMessenger
import numpy as np
import cv2
import time
import json
import argparse
import os.path
import sys
from GestureRecognition.simplehandrecogniser import SimpleHandRecogniser
from threading import Thread
from queue import Queue

import MyRaft.node as raft
import MyRaft.leader as leader
import DecisionSystem.CentralisedDecision.commander as commander
import DecisionSystem.CentralisedDecision.messenger as messenger
import DecisionSystem.CentralisedDecision.ballotvoter as voter
import DecisionSystem.CentralisedDecision.videoget as videoget

print("Parsing args")
parser = argparse.ArgumentParser(description="Runs a file with OpenCV and gets consensus from the swarm.")

parser.add_argument('-V', '--video', help="Path to video file.")

args = parser.parse_args()

recogniser = SimpleHandRecogniser(None)

# Check that a video file is specified and that it exists.
if args.video:
    print('finding video')
    if not os.path.isfile(args.video):
        print("Input video file ", args.video, " doesn't exist")
        sys.exit(1)
else:
    # Exit if no video file specified - we aren't using a webcam here.
    sys.exit(1)

def on_vote():
    # Get the current frame of the camera and process what hand
    # is currently being seen.
    print('getting frame')
    # Need to copy rather than just take a reference, as the frame will
    # constantly be changing.
    global vd
    recogniser.set_frame(np.copy(vd.frame))
    print('Got frame, voting with recogniser')
    return recogniser.get_gesture()

def handle_agreement(vote):
    # Assumed: BallotVoter requires an agreement handler; this server only
    # displays frames, so just log the result.
    print("Swarm agreed on: %s" % vote)

def connect_to_broker(mqtt):
    print("Connecting to broker")
    max_collisions = 100
    collisions = 1
    # Exponential backoff between reconnection attempts.
    while not mqtt.connect() and collisions <= max_collisions:
        print("Reconnecting in %s" % (2 ** collisions - 1))
        time.sleep(2 ** collisions - 1)
        collisions += 1

# MqttMessenger needs the node configuration (assumed to live in config.json,
# as in the other entry points).
with open('config.json') as cfg_file:
    cfg = json.load(cfg_file)
mqtt = MqttMessenger(cfg)
v = BallotVoter(on_vote, handle_agreement, mqtt)

def on_disconnect(rc):
    print("Client disconnected from broker")
    i = input("Would you like to reconnect? (y|n)")
    if i == 'y':
        global mqtt
        connect_to_broker(mqtt)

mqtt.add_disconnect_callback(on_disconnect)
connect_to_broker(mqtt)

# Start the video capture at the next whole minute (waiting an extra minute
# if fewer than 20 seconds remain in the current one).
current_time_sec = time.gmtime(time.time()).tm_sec
if current_time_sec < 40:
    time.sleep(60 - current_time_sec)
else:
    time.sleep(60 - current_time_sec + 60)
print('loading video')
# Assumed: the original omitted creating the capture thread the loop below
# reads frames from.
vd = videoget.VideoGet(Queue(5), args.video).start()

print('Press q to quit the server, g to get votes/consensus')

while True:
    if vd.frame is None:
        continue
    frame = np.copy(vd.frame)
    cv2.imshow('Frame', frame)
    k = cv2.waitKey(33)
    if k == ord('q'):
        break
    elif k == -1:
        continue
    elif k == ord('g'):
        # Get votes (assumed intent: ask the commander to collect votes).
        v.request_vote()
15
car/src/DecisionSystem/CentralisedDecision/central_server.py
Normal file
@@ -0,0 +1,15 @@
from DecisionSystem.CentralisedDecision import commander
from DecisionSystem.CentralisedDecision.messenger import MqttMessenger
import json

# MqttMessenger needs the node configuration (assumed to live in config.json,
# as in the other entry points).
with open('config.json') as cfg_file:
    cfg = json.load(cfg_file)

mqtt = MqttMessenger(cfg)
c = commander.Commander(mqtt, 10)
mqtt.connect()

f = input("Press any key and enter other than q to get the current observation of the swarm: ")

while f != "q":
    print("Vote is: ")
    print(c.get_votes())
    f = input("Press any key and enter other than q to get the current observation of the swarm: ")

print("Thanks for trying!")
@@ -0,0 +1,106 @@
"""This module provides an instance of the centralised, distributed voter."""

from queue import Queue
import json
import argparse

import numpy as np
import cv2

import MyRaft.node as raft
import MyRaft.leader as leader
import DecisionSystem.CentralisedDecision.commander as commander
import DecisionSystem.CentralisedDecision.messenger as messenger
import DecisionSystem.CentralisedDecision.ballotvoter as voter
import DecisionSystem.CentralisedDecision.videoget as videoget
import GestureRecognition.simplehandrecogniser as shr
import GestureRecognition.starkaleid as sk


class Instance:
    """An instance of the centralised, distributed approach to voting."""

    def __init__(self, node_config='config.json', video_file=0):
        with open(node_config) as f:
            self.cfg = json.load(f)
        self.mqtt = messenger.MqttMessenger(self.cfg)
        self.we_lead = False
        self.node = raft.RaftGrpcNode(node_config)
        print("Node initialised")
        self.node.add_state_change(self.on_state_changed)

        self.voter = voter.BallotVoter(self.on_vote, self.handle_agreement, self.mqtt)
        self.commander = commander.Commander(self.mqtt)
        self.recogniser = shr.SimpleHandRecogniser(None)

        self.last_vote = -1

        self.q = Queue(5)
        self.frame = None
        self.vd = videoget.VideoGet(self.q, video_file)

        self.kaleid = False
        print("Initialised the instance")

    def on_state_changed(self):
        """Callback for the raft node's state changing."""
        if isinstance(self.node._current_state, leader.Leader):
            # We are now the commander (or leader).
            self.commander = commander.Commander(self.mqtt)
        else:
            # No longer, or never were, the leader.
            try:
                del self.commander
            except AttributeError:
                # Deleting a missing attribute raises AttributeError.
                pass

    def start(self):
        self.vd.start()
        self.mqtt.connect()
        go = True
        while go:
            if self.kaleid:
                go = self.show_kaleidoscope()
            else:
                go = self.show_normal()

    def show_normal(self):
        self.frame = np.copy(self.q.get())
        cv2.imshow('Frame', self.frame)
        # Read the key once; calling waitKey twice would consume two events.
        k = cv2.waitKey(1) & 0xFF
        if k == ord('q'):
            return False
        elif k == ord('g'):
            self.voter.request_vote()
        return True

    def show_kaleidoscope(self):
        self.frame = sk.make_kaleidoscope(np.copy(self.q.get()), 12)
        cv2.imshow('Frame', self.frame)
        k = cv2.waitKey(1) & 0xFF
        if k == ord('q'):
            return False
        elif k == ord('g'):
            self.voter.request_vote()
        return True

    def on_vote(self):
        # Get the current frame of the camera and process what hand
        # is currently being seen.
        print('getting frame')
        # Need to copy rather than just take a reference, as the frame will
        # constantly be changing.
        self.recogniser.set_frame(np.copy(self.frame))
        print('Got frame, voting with recogniser')
        gesture = self.recogniser.get_gesture()
        self.last_vote = gesture
        return gesture

    def handle_agreement(self, vote):
        if vote == 5:
            self.kaleid = True
        else:
            self.kaleid = False


parser = argparse.ArgumentParser(description="An instance of CAIDE")

if __name__ == "__main__":
    instance = Instance(video_file="/Users/piv/Documents/Projects/Experiments/Camera1/video.mp4")
    instance.start()
119
car/src/DecisionSystem/CentralisedDecision/commander.py
Normal file
@@ -0,0 +1,119 @@
import time
from DecisionSystem.messages import Message, CommanderWill, RequestVote, GetSwarmParticipants, deserialise, ClientVoteRequest, VoteResult
import json
import numpy as np


class Commander:
    def __init__(self, messenger, timeout=60):
        '''
        Initial/default waiting time is 1 minute for votes to come in.
        '''
        self.timeout = timeout
        self.currentVote = None

        # Stores voters that connect, to maintain a majority.
        # Voters who do not vote in the latest round are removed.
        # (Instance-level state, so separate Commanders don't share lists.)
        self._connectedVoters = []
        # Dict has format: {clientId: vote}
        self._votes = {}
        self._taking_votes = False

        self._messenger = messenger
        self._messenger.add_connect(self.on_connect)
        self._messenger.add_message_callback(self.on_message)
        self._messenger.add_disconnect_callback(self.on_disconnect)
        print('Connecting')

    def make_decision(self):
        # Should change this to follow the strategy pattern, for different
        # implementations of making a decision on the votes.
        print("Making a decision")
        votes = self._votes.values()
        dif_votes = {}

        for vote in votes:
            # Count the occurrences of each distinct vote.
            if vote in dif_votes:
                dif_votes[vote] = dif_votes[vote] + 1
            else:
                dif_votes[vote] = 1

        max_vote = ""
        max_vote_num = 0
        # Could try using a numpy array for this.

        for vote in dif_votes.keys():
            if dif_votes[vote] > max_vote_num:
                max_vote = vote
                max_vote_num = dif_votes[vote]

        print("Made decision!")
        return max_vote

    def get_votes(self):
        # Should abstract messaging to another class.
        print("Gathering votes")
        self._taking_votes = True
        # Publish a message that votes are needed.
        print("Sending request message")
        self._messenger.broadcast_message(self._messenger.swarm, RequestVote(self._messenger.id).serialise())
        print("Published message")
        time.sleep(self.timeout)
        self._taking_votes = False
        # TODO: Work out how to broadcast votes back to the swarm, maybe using raft?
        return self.make_decision()

    def on_message(self, message):
        print("Message received")
        messageD = None
        try:
            messageD = deserialise(message.payload)
        except Exception:
            print("Incorrect message has been sent")
            return

        # Need to consider that a malicious message may have a type with incorrect subtypes.
        if messageD.type == "connect":
            print("Voter connected!")
            # Voter just connected/reconnected; the sender field carries its id.
            if messageD.sender not in self._connectedVoters:
                self._connectedVoters.append(messageD.sender)
        elif messageD.type == "vote":
            print("Received a vote!")
            # Voter is sending in their vote.
            print(messageD.data["vote"])
            print("From: ", messageD.sender)
            if self._taking_votes:
                # The commander must have requested votes, and the timeout
                # has not yet occurred.
                # Only record the vote if the client has not already voted.
                if messageD.sender not in self._votes:
                    self._votes[messageD.sender] = int(messageD.data["vote"])
        elif messageD.type == ClientVoteRequest().type:
            # Received a request to get votes/consensus.
            self.get_votes()

        elif messageD.type == "disconnected":
            print("Voter disconnected :(")
            self._connectedVoters.remove(messageD.sender)

    def on_connect(self, rc):
        # Subscriptions are now handled by the MQTT messenger; this is just
        # here for convenience later.
        pass

    def get_participants(self):
        self._messenger.broadcast_message(self._messenger.swarm, GetSwarmParticipants().serialise())
        # The commander needs a will message too, for the decentralised version,
        # so the voters know to pick a new commander.
        # If using Apache ZooKeeper this won't be needed.
        # That's the wrong method for setting a will:
        # self.client.publish("swarm1/voters", CommanderWill(self.client._client_id).serialise())

    def on_disconnect(self, rc):
        pass

    def propogate_result(self, result):
        # Assumed completion: the original left the payload blank; broadcasting
        # a VoteResult matches what BallotVoter.on_message expects.
        self._messenger.broadcast_message(self._messenger.swarm, VoteResult(result).serialise())
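For clarity, the majority count in make_decision is equivalent to the following standard-library one-liner (editor's sketch, not the repo's code; ties break arbitrarily in both versions):

from collections import Counter

def make_decision(votes):
    # votes is an iterable like Commander._votes.values()
    vote, count = Counter(votes).most_common(1)[0]  # arbitrary winner on ties
    return vote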
138
car/src/DecisionSystem/CentralisedDecision/messenger.py
Normal file
@@ -0,0 +1,138 @@
import paho.mqtt.client as mqtt
import json
import random


class Messenger:
    def __init__(self):
        # Instance-level lists, so separate messengers don't share callbacks.
        self._connect_callbacks = []
        self._disconnect_callbacks = []
        self._message_callbacks = []

    def broadcast_message(self, topic, message):
        """
        Broadcasts the specified message to the swarm based upon its topic (or group).
        """
        raise NotImplementedError

    def unicast_message(self, target, message):
        """
        Sends the specified message to a single target.
        """
        raise NotImplementedError

    def connect(self):
        """
        Connect to the swarm.
        """
        raise NotImplementedError

    def disconnect(self):
        """
        Disconnect from the swarm.
        """
        raise NotImplementedError

    def add_connect(self, connect):
        """
        Adds a callback to do something else once we are connected.
        """
        self._connect_callbacks.append(connect)

    def on_connect(self, code=None):
        """
        Called once the messenger connects to the swarm.
        """
        for cb in self._connect_callbacks:
            cb(code)

    def on_disconnect(self, code=None):
        """
        Called when the messenger is disconnected from the swarm.
        """
        for cb in self._disconnect_callbacks:
            cb(code)

    def add_disconnect_callback(self, on_disconnect):
        """
        Adds a callback for when the messenger is disconnected.
        """
        self._disconnect_callbacks.append(on_disconnect)

    def add_message_callback(self, on_message):
        """
        Adds a callback for when the messenger receives a message.
        """
        self._message_callbacks.append(on_message)

    def on_message(self, message):
        """
        Called when the messenger receives a message.
        """
        for cb in self._message_callbacks:
            cb(message)

    @property
    def id(self):
        """
        The id this messenger uses in communication.
        """
        raise NotImplementedError

    @property
    def swarm(self):
        """
        Gets the name of the swarm this instance is a part of.
        """
        raise NotImplementedError


class MqttMessenger(Messenger):
    """A messenger that uses MQTT."""
    def __init__(self, configuration):
        super().__init__()
        self._cfg = configuration
        self._client = mqtt.Client(client_id=str(random.randint(0, 500)))
        self._client.on_connect = self.on_connect
        self._client.on_message = self.on_message
        self._client.on_disconnect = self.on_disconnect

    def on_message(self, client, userdata, message):
        Messenger.on_message(self, message)

    def on_connect(self, client, userdata, flags, rc):
        # Subscribe to the swarm specified in the config.
        self._client.subscribe(self._cfg['mqtt']['swarm'])

        # Also subscribe to our own topic for unicast messages.
        self._client.subscribe(self._cfg['mqtt']['swarm'] + str(self._client._client_id))
        Messenger.on_connect(self, rc)

    def on_disconnect(self, client, userdata, rc):
        Messenger.on_disconnect(self, rc)

    def broadcast_message(self, topic, message):
        # Callers pass the target topic explicitly (usually messenger.swarm),
        # so publish there; the base class signature now matches this usage.
        self._client.publish(topic, message, qos=1)

    def unicast_message(self, target, message):
        self._client.publish(target, message, qos=1)

    def connect(self):
        try:
            self._client.connect(self._cfg['mqtt']['host'],
                                 int(self._cfg['mqtt']['port']),
                                 int(self._cfg['mqtt']['timeout']))
        except Exception:
            print("Could not connect to broker")
            return False

        self._client.loop_start()
        return True

    def disconnect(self):
        self._client.disconnect()

    @property
    def id(self):
        return self._client._client_id

    @property
    def swarm(self):
        return self._cfg['mqtt']['swarm']
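A sketch of the topic conventions implied by on_connect above (editor's illustration; the host, port, and swarm values are assumptions, only the key names come from the code):

m = MqttMessenger({"mqtt": {"host": "localhost", "port": 1883,
                            "timeout": 60, "swarm": "swarm1/"}})
m.connect()
m.broadcast_message(m.swarm, b"hello swarm")       # every subscriber receives
m.unicast_message(m.swarm + "42", b"hello node")   # only the client with id "42" receives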
45
car/src/DecisionSystem/CentralisedDecision/videoget.py
Normal file
@@ -0,0 +1,45 @@
import numpy as np
import cv2
from threading import Thread
from queue import Queue
import time


class VideoGet:
    '''
    Code taken from Najam R Syed, available here:
    https://github.com/nrsyed/computer-vision/tree/master/multithread
    '''
    def __init__(self, q, src):
        '''
        Must provide a source so we don't accidentally start the camera at work.
        '''
        self._stream = cv2.VideoCapture(src)
        (self.grabbed, self.frame) = self._stream.read()
        self.stopped = False
        self.q = q
        self.q.put(np.copy(self.frame))
        self.src = src

    def start(self):
        Thread(target=self.get, args=()).start()
        return self

    def get(self):
        while not self.stopped:
            if not self.grabbed:
                # self.stopped = True
                print('frame not grabbed')
                # Reopen the source to start a new feed.
                self._stream.release()
                self._stream = cv2.VideoCapture(self.src)
                # time.sleep(2)
                self.grabbed, self.frame = self._stream.read()
            else:
                (self.grabbed, self.frame) = self._stream.read()
                if self.q.full():
                    self.q.get()
                self.q.put(np.copy(self.frame))
                time.sleep(0.03)  # Approximately 30fps

    def stop(self):
        self.stopped = True
128
car/src/DecisionSystem/DecentralisedActivityFusion/voter.py
Normal file
@@ -0,0 +1,128 @@
import paho.mqtt.client as mqtt
import time
import json
import umsgpack
import numpy as np


class Voter:
    '''
    This class acts to replicate sensor information within the network to come to a
    consensus on an activity occurrence. This is based upon research by Song et al.,
    available at:
    https://ieeexplore.ieee.org/document/5484586

    The main advantage of this approach, as opposed to techniques such as using
    ZooKeeper or Consul, is that it can be completely decentralised, and so works
    without a central server or the need to elect one. Additionally, it does not
    require all nodes to run a ZooKeeper/Consul server instance; those were not
    designed for these constrained combat environments, will fail if half the nodes
    fail, and use a lot of resources handling services not required by this task.

    The original approach in the paper requires some previous training before
    sensing, so that there is a probability of a given action based upon the
    previous set of actions.
    '''

    def __init__(self, on_vote, swarm_name):
        '''
        on_vote: Callback to get the required vote to broadcast.
        '''
        # Instance-level state, so separate Voters don't share votes.
        self._votes = {}
        self._connected_voters = []
        self._taking_votes = False

        # Load config file
        cfg = None
        with open('config.json') as json_config:
            cfg = json.load(json_config)
        self._cfg = cfg
        self.on_vote = on_vote
        self._swarm = swarm_name
        self._client = mqtt.Client()
        self._client.on_message = self.on_message
        self._client.on_connect = self.on_connect
        self._client.connect(cfg["mqtt"]["host"], cfg["mqtt"]["port"], cfg["mqtt"]["timeout"])
        self._client.loop_start()

    def submit_vote(self):
        # Publish to the swarm, where all other voters will receive the vote.
        # collect_vote() must be called; passing the bound method would publish
        # nothing useful.
        self._client.publish(self._swarm, self.collect_vote())
        self._taking_votes = True
        time.sleep(self._cfg["mqtt"]["timeout"])
        self._taking_votes = False
        # Wait a certain amount of time for responses, then fuse the information.
        self.fuse_algorithm()

    # Need the error and the number of timestamps since voting started to
    # finalise the consensus.

    def fuse_algorithm(self):
        # First calculate vi -> the actual vote that is taken
        # (or the probability that the observation is a given label).
        # We're just going to use 1 for the detected gesture and 0 for all others.
        # vi is per hand (per action in the paper), but we're just going to do a
        # single hand for our purposes. Will be able to use the CNN for all
        # hands/gestures if we want to.
        vi = np.zeros((6, 1))
        # Set the correct vi.
        vote = self.on_vote()
        vi[vote] = 1
        # Now send this off to the other nodes. Potentially using gossip...

        # Set the diagonal of ANDvi to the elements of vi.
        # This should actually be ANDvj, as it is for each observation received.
        ANDvi = np.diag(vi.flatten())

        # M is the probability of going from one state to the next, which
        # is assumed to be uniform for our situation - someone is just as likely
        # to raise 5 fingers from two as from any other count.
        # A 6x6 matrix is generated with all the same probability to show this.
        # Remember they could be holding up no fingers...
        # m = np.full((6,6), 0.2)

        # Y1T = np.full((6,1),1)

        # Compute the consensus state estimate by taking the difference between
        # our observations and all others individually.

        # Moving to an approach that does not require the previous
        # timestep (or so much math...)
        # First take the other information and fuse, using an algorithm
        # as appropriate.
        pass

    def custom_fuse(self):
        vi = np.zeros((6, 1))
        # Set the correct vi.
        vote = self.on_vote()
        vi[vote] = 1

    def on_message(self, client, userdata, message):
        try:
            message_dict = umsgpack.unpackb(message.payload)
        except Exception:
            print("Incorrect message received")
            return

        if message_dict["type"] == "vote":
            # Received a vote.
            if self._taking_votes:
                self._votes[message_dict["client"]] = message_dict["vote"]

        elif message_dict["type"] == "connect":
            # A voter connected to the swarm.
            self._connected_voters.append(message_dict["client"])

        elif message_dict["type"] == "disconnect":
            # Sent as the voter's will message.
            self._connected_voters.remove(message_dict["client"])

    def on_connect(self, client, userdata, flags, rc):
        print("Connected with result code " + str(rc))
        self._client.subscribe(self._swarm)

    def collect_vote(self):
        vote_message = umsgpack.packb({"type": "vote",
                                       "client": self._client._client_id,
                                       "vote": self.on_vote()})
        return vote_message

    def start_vote(self):
        pass
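fuse_algorithm above stops short of combining observations. A sketch of the simplest fusion consistent with its comments (one-hot local vote, averaged across nodes, then argmax); this is an editor's assumption about the intended final step, not the paper's full estimator:

import numpy as np

def fuse_one_hot(votes):
    # votes: list of ints in 0..5, one per node (number of fingers seen).
    vs = np.zeros((len(votes), 6))
    vs[np.arange(len(votes)), votes] = 1   # one-hot encode each node's vote
    consensus = vs.mean(axis=0)            # per-gesture support across nodes
    return int(np.argmax(consensus))

print(fuse_one_hot([5, 5, 2]))  # -> 5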
0
car/src/DecisionSystem/__init__.py
Normal file
101
car/src/DecisionSystem/messages.py
Normal file
@@ -0,0 +1,101 @@
import umsgpack
import uuid


class Message:
    _type = None

    def __init__(self, sender="", data=None):
        self._sender = sender
        # Avoid a mutable default argument: a shared class-level dict would
        # leak data between message instances.
        self._data = data if data is not None else {}

    @property
    def sender(self):
        return self._sender

    @sender.setter
    def sender(self, value):
        self._sender = value

    @property
    def data(self):
        return self._data

    # I love using keywords...
    @property
    def type(self):
        return self._type

    @type.setter
    def type(self, value):
        self._type = value

    def serialise(self):
        return umsgpack.packb({"type": self.type, "sender": self.sender, "data": self.data})


# Should make this static in the Message class.
def deserialise(obj):
    """
    Deserialises a given MessagePack object into a Message.
    """
    m = Message()
    unpacked = umsgpack.unpackb(obj)
    print('Unpacked object')
    print(unpacked)
    m.type = unpacked["type"]
    m._sender = unpacked["sender"]
    m._data = unpacked["data"]
    return m


class RequestLeader(Message):
    _type = "RequestLeader"


class ProposeMessage(Message):
    _type = "Propose"


class ElectionVote(Message):
    _type = "Elect"


class Commit(Message):
    _type = "Commit"


class ConnectSwarm(Message):
    _type = "connect"


class RequestVote(Message):
    _type = "reqvote"


class ConnectResponse(Message):
    _type = "accepted"


class VoterWill(Message):
    _type = "disconnectedvoter"


class CommanderWill(Message):
    _type = "disconnectedcommander"


class SubmitVote(Message):
    _type = "vote"

    def __init__(self, vote=None, sender="", data=None):
        Message.__init__(self, sender, data)
        self._data["vote"] = vote

    @property
    def vote(self):
        return self._data["vote"]

    @vote.setter
    def vote(self, value):
        self._data["vote"] = value


class GetSwarmParticipants(Message):
    _type = "listening"


class VoteResult(Message):
    _type = "voteresult"

    def __init__(self, vote, sender='', data=None):
        super().__init__(sender=sender, data=data)
        self._data["vote"] = vote


class ClientVoteRequest(Message):
    _type = "clientvoterequest"
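A quick round-trip sketch for the classes above (editor's illustration, not part of the commit; umsgpack is the library the module already imports):

msg = SubmitVote(vote=3, sender="client-7")
packed = msg.serialise()          # msgpack bytes
recovered = deserialise(packed)   # generic Message instance
assert recovered.type == "vote"
assert recovered.data["vote"] == 3
assert recovered.sender == "client-7"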
121
car/src/GestureRecognition/HandRecHSV.py
Normal file
@@ -0,0 +1,121 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 22 10:51:21 2018

@author: pivatom
"""

import numpy as np
import cv2

# Raw strings so the backslashes in the Windows paths aren't treated as escapes.
img = cv2.imread(r'H:\car\GestureRecognition\IMG_0825.jpg', 1)
# img = cv2.imread(r'H:\car\GestureRecognition\IMG_0818.png', 1)

# Downscale the image
img = cv2.resize(img, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_AREA)

e1 = cv2.getTickCount()

# Hand localisation... possibly with YOLOv3? v2 is faster though...

img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)

# Need to shift red pixels so they can be 0-20 rather than 250-~20
img_hsv[:,:,0] = img_hsv[:,:,0] + 30
img_hsv[:,:,0] = np.where(img_hsv[:,:,0] > 179, img_hsv[:,:,0] - 179, img_hsv[:,:,0])

img_hsv = cv2.GaussianBlur(img_hsv, (5,5), 0)

lower_skin = (0, 0, 153)
upper_skin = (45, 153, 255)

# Only need the mask, as we can just use this to do the hand segmentation.
mask = cv2.inRange(img_hsv, lower_skin, upper_skin)

# This takes a whole millisecond (approx), and does not seem very worth the cost.
blur = cv2.GaussianBlur(mask, (5,5), 0)
ret, img_thresh = cv2.threshold(blur, 50, 255, cv2.THRESH_BINARY)

# Uncomment if not using blur and threshold.
# img_thresh = mask

k = np.sum(img_thresh) / 255

# Taking indices for the number of rows/columns.
x_ind = np.arange(0, img_thresh.shape[1])
y_ind = np.arange(0, img_thresh.shape[0])
coords_x = np.zeros((img_thresh.shape[0], img_thresh.shape[1]))
coords_y = np.zeros((img_thresh.shape[0], img_thresh.shape[1]))
coords_x[:,:] = x_ind

# Even this is extremely quick, as it goes through rows in the numpy array,
# which in python is much faster than columns.
for element in y_ind:
    coords_y[element,:] = element

# Now get the average x value and y value for the centre of gravity.
xb = int(np.sum(coords_x[img_thresh == 255])/k)
yb = int(np.sum(coords_y[img_thresh == 255])/k)

centre = (xb, yb)

# Calculate the radius of the circle:
# May need to calculate the diameter as well.
# Just take the min/max x and y values.
x_min = np.min(coords_x[img_thresh == 255])
x_max = np.max(coords_x[img_thresh == 255])
y_min = np.min(coords_y[img_thresh == 255])
y_max = np.max(coords_y[img_thresh == 255])

candidate_pts = [(x_min, y_min), (x_min, y_max), (x_max, y_min), (x_max, y_max)]
radius = 0

# Check each point to see which is furthest from the centre.
for pt in candidate_pts:
    # Calculate the Euclidean distance.
    new_distance = ((pt[0] - centre[0])**2 + (pt[1] - centre[1])**2)**(1/2)
    if new_distance > radius:
        radius = new_distance

radius = int(radius * 0.52)

# 140 needs to be replaced with a predicted value, i.e. not a magic number.
# cv2.circle(img_thresh, centre, radius, (120,0,0), 3)

def calc_pos_y(x):
    return int((radius**2 - (x - centre[0])**2)**(1/2) + centre[1])

# Now go around the circle to count the number of times the mask flips
# 0->255 or vice-versa.
# First just do it the naive way with loops.
# Equation of the circle:
# y = sqrt(r2 - (x-c)2) + c
# Will just increment x to check; no need to loop over y as well.
# This is actually fast - it was the print debug statements that made it seem
# slow; it takes just 6ms. Could try a kernel method?
prev_x = centre[0] - radius
prev_y = [calc_pos_y(centre[0] - radius), calc_pos_y(centre[0] - radius)]
num_change = 0
for x in range(centre[0] - radius + 1, centre[0] + radius):
    ypos = calc_pos_y(x)
    y = [ypos, centre[1] - (ypos - centre[1])]
    if img_thresh[y[0], x] != img_thresh[prev_y[0], prev_x]:
        num_change += 1
    if img_thresh[y[1], x] != img_thresh[prev_y[1], prev_x] and y[0] != y[1]:
        num_change += 1
    prev_x = x
    prev_y = y

fingers = num_change / 2 - 1

print("Num fingers: " + str(fingers))

e2 = cv2.getTickCount()
t = (e2 - e1)/cv2.getTickFrequency()
print(t)

cv2.imshow("Threshold", img_thresh)
cv2.waitKey(0)
cv2.destroyAllWindows()
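Why the closing formula works: each region of the hand that crosses the sampling circle (each raised finger, plus the wrist/palm) produces exactly two mask transitions, one entering and one leaving, so num_change = 2 * (fingers + 1) and fingers = num_change / 2 - 1. For example, eight transitions around the circle means three raised fingers.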
49
car/src/GestureRecognition/HandRecV2.py
Normal file
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 22 09:21:04 2018

@author: pivatom
"""

import numpy as np
import cv2

min_seg_threshold = 1.05
max_seg_threshold = 4

def calcSkinSample(event, x, y, flags, param):
    # Without the global declaration, these assignments would only create
    # locals and never update the module-level thresholds.
    global min_seg_threshold, max_seg_threshold
    if event == cv2.EVENT_FLAG_LBUTTON:
        # Rows index y and columns index x.
        sample = img[y:y+10, x:x+10]
        lo = 255
        hi = 0
        for line in sample:
            avg = np.sum(line)/10
            if avg < lo:
                lo = avg
            if avg > hi:
                hi = avg
        min_seg_threshold = lo
        max_seg_threshold = hi

def draw_rect(event, x, y, flags, param):
    if event == cv2.EVENT_FLAG_LBUTTON:
        print("LbuttonClick")
        cv2.rectangle(img, (x,y), (x+10, y+10), (0,0,255), 3)

img = cv2.imread(r'H:\car\GestureRecognition\IMG_0818.png', 1)

# Downscale the image
img = cv2.resize(img, None, fx=0.1, fy=0.1, interpolation=cv2.INTER_AREA)

cv2.namedWindow("Hand")
cv2.setMouseCallback("Hand", draw_rect)

# Prevent divide by zero by just forcing the pixel to be ignored.
#np.where(img[:,:,1] == 0, 0, img[:,:,1])
#img[(img[:,:,2]/img[:,:,1] > min_seg_threshold) & (img[:,:,2]/img[:,:,1] < max_seg_threshold)] = [255,255,255]

while(1):
    cv2.imshow("Hand", img)
    if cv2.waitKey(0):
        break
cv2.destroyAllWindows()
BIN
car/src/GestureRecognition/IMG_0818.png
Normal file
Binary file not shown. (13 MiB)
BIN
car/src/GestureRecognition/IMG_0825.jpg
Normal file
Binary file not shown. (1.9 MiB)
BIN
car/src/GestureRecognition/Neural Network hand Tracking.pdf
Normal file
Binary file not shown.
381
car/src/GestureRecognition/SimpleHandRecogniser.py
Normal file
@@ -0,0 +1,381 @@
import numpy as np
import cv2

from GestureRecognition.handrecogniser import HandRecogniser


class SimpleHandRecogniser(HandRecogniser):
    def __init__(self, frame):
        self.img = frame
        self.graph = None
        self.sess = None
        self.img_cut = None

    def __calc_pos_y(self, x, radius, centre):
        """
        Calculates the position of y on a circle with the given radius and
        centre, given coordinate x.
        """
        return int((radius**2 - (x - centre[0])**2)**(1/2) + centre[1])

    def __segment_image(self):
        """
        Segments the hand from the rest of the image to get a threshold.
        """
        self.img_cut = cv2.GaussianBlur(self.img_cut, (5, 5), 0)

        lower_skin = (0, 0, 153)
        upper_skin = (45, 153, 255)

        # Only need the mask, as we can just use this to do the hand segmentation.
        self.img_cut = cv2.inRange(self.img_cut, lower_skin, upper_skin)

        # Apply another blur to remove any small holes/noise.
        self.img_cut = self.__denoise(self.img_cut)
        _, self.img_cut = cv2.threshold(self.img_cut, 50, 255, cv2.THRESH_BINARY)

    def __denoise(self, image):
        """
        Applies a 5x5 Gaussian blur to remove noise from the image.
        """
        return cv2.GaussianBlur(image, (5, 5), 0)

    def __calc_circle(self, image, radius_percent=0.6):
        """
        Calculates the equation of the circle (radius, centre) from a given
        threshold image, so that the circle's centre is the centre of gravity
        of the thresholded pixels, and the radius is by default 60% of the
        distance to the furthest thresholded corner point.
        """
        k = np.sum(self.img_cut) / 255

        # Taking indices for the number of rows/columns.
        x_ind = np.arange(0, self.img_cut.shape[1])
        y_ind = np.arange(0, self.img_cut.shape[0])
        coords_x = np.zeros((self.img_cut.shape[0], self.img_cut.shape[1]))
        coords_y = np.zeros((self.img_cut.shape[0], self.img_cut.shape[1]))
        coords_x[:, :] = x_ind

        # Even this is extremely quick, as it goes through rows in the numpy
        # array, which in python is much faster than columns.
        for element in y_ind:
            coords_y[element, :] = element

        # Now get the average x value and y value for the centre of gravity.
        centre = (int(np.sum(coords_x[self.img_cut == 255])/k), int(np.sum(coords_y[self.img_cut == 255])/k))

        # Calculate the radius of the circle:
        # May need to calculate the diameter as well.
        # Just take the min/max x and y values.
        x_min = np.min(coords_x[self.img_cut == 255])
        x_max = np.max(coords_x[self.img_cut == 255])
        y_min = np.min(coords_y[self.img_cut == 255])
        y_max = np.max(coords_y[self.img_cut == 255])

        candidate_pts = [(x_min, y_min), (x_min, y_max), (x_max, y_min), (x_max, y_max)]
        radius = 0

        # Check each point to see which is furthest from the centre.
        for pt in candidate_pts:
            # Calculate the Euclidean distance.
            new_distance = ((pt[0] - centre[0])**2 + (pt[1] - centre[1])**2)**(1/2)
            if new_distance > radius:
                radius = new_distance

        radius = int(radius * radius_percent)

        return radius, centre

    def __calc_circles(self, image, radius_percent_range=[0.6, 0.8], step=0.1):
        """
        Calculates the equation of the circle (radii, centre) as above, but
        with several radii, so that we can get a more accurate estimate from a
        given threshold image; the circle's centre is the centre of gravity of
        the thresholded pixels.
        """
        k = np.sum(self.img_cut) / 255

        # Taking indices for the number of rows/columns.
        x_ind = np.arange(0, self.img_cut.shape[1])
        y_ind = np.arange(0, self.img_cut.shape[0])
        coords_x = np.zeros((self.img_cut.shape[0], self.img_cut.shape[1]))
        coords_y = np.zeros((self.img_cut.shape[0], self.img_cut.shape[1]))
        coords_x[:,:] = x_ind

        # Even this is extremely quick, as it goes through rows in the numpy
        # array, which in python is much faster than columns.
        for element in y_ind:
            coords_y[element,:] = element

        # Now get the average x value and y value for the centre of gravity.
        centre = (int(np.sum(coords_x[self.img_cut == 255])/k), int(np.sum(coords_y[self.img_cut == 255])/k))

        # Calculate the radius of the circle:
        # Just take the min/max x and y values.
        x_min = np.min(coords_x[self.img_cut == 255])
        x_max = np.max(coords_x[self.img_cut == 255])
        y_min = np.min(coords_y[self.img_cut == 255])
        y_max = np.max(coords_y[self.img_cut == 255])

        candidate_pts = [(x_min, y_min), (x_min, y_max), (x_max, y_min), (x_max, y_max)]
        radius = 0

        # Check each point to see which is furthest from the centre.
        for pt in candidate_pts:
            # Calculate the Euclidean distance.
            new_distance = ((pt[0] - centre[0])**2 + (pt[1] - centre[1])**2)**(1/2)
            if new_distance > radius:
                radius = new_distance

        radii = []
        # range() does not accept float steps, so use np.arange, and append
        # each candidate radius (the original `radii += int(...)` would fail
        # on a list += int).
        for i in np.arange(radius_percent_range[0], radius_percent_range[1], step):
            radii.append(int(radius * i))

        return radii, centre

    def __shift_pixels(self, image, shift_radius):
        image[:, :, 0] = image[:, :, 0] + shift_radius
        image[:, :, 0] = np.where(image[:, :, 0] > 179, image[:, :, 0] - 179, image[:, :, 0])
        return image

    def set_frame(self, frame):
        self.img = frame

    # Source: Victor Dibia
    # Link: https://github.com/victordibia/handtracking
    # Taken the code straight from his example, as it works perfectly. This is
    # specifically from the load_inference_graph method that he wrote, and will
    # load the graph into memory if one has not already been loaded for this object.
    # def load_inference_graph(self):
    #     """Loads a tensorflow model checkpoint into memory"""

    #     if self.graph != None and self.sess != None:
    #         # Don't load more than once, to save time...
    #         return

    #     PATH_TO_CKPT = '/Users/piv/Documents/Projects/car/GestureRecognition/frozen_inference_graph.pb'
    #     # load frozen tensorflow model into memory
    #     detection_graph = tf.Graph()
    #     with detection_graph.as_default():
    #         od_graph_def = tf.GraphDef()
    #         with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    #             serialized_graph = fid.read()
    #             od_graph_def.ParseFromString(serialized_graph)
    #             tf.import_graph_def(od_graph_def, name='')
    #         sess = tf.Session(graph=detection_graph)
    #     self.graph = detection_graph
    #     self.sess = sess

    # Source: Victor Dibia
    # Link: https://github.com/victordibia/handtracking
    # Taken the code straight from his example, as it works perfectly. This is
    # specifically from the detect_hand method that he wrote, as other processing
    # is required for the hand recognition to work correctly.
    # def detect_hand_tensorflow(self, detection_graph, sess):
    #     """ Detects hands in a frame using a CNN

    #     detection_graph -- The CNN to use to detect the hand.
    #     sess -- The tensorflow session for the given graph
    #     """

    #     image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    #     detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    #     detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')

    #     detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')

    #     num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    #     img_expanded = np.expand_dims(self.img, axis=0)

    #     (boxes, scores, classes, num) = sess.run(
    #         [detection_boxes, detection_scores, detection_classes, num_detections],
    #         feed_dict={image_tensor: img_expanded})
    #     print('finished detection')
    #     return np.squeeze(boxes), np.squeeze(scores)

    def load_cv_net(self, graph_path, names_path):
        """Loads a tensorflow object detection network using OpenCV.

        Arguments
        graph_path: Path to the tensorflow frozen inference graph (something.pb)
        names_path: Path to the tensorflow (something.pbtxt) file.
        """
        self.net = cv2.dnn.readNetFromTensorflow(graph_path, names_path)

    def detect_hand_opencv(self):
        """Performs hand detection using a tensorflow CNN run through OpenCV."""
        if self.img is None:
            return

        rows = self.img.shape[0]
        cols = self.img.shape[1]

        self.net.setInput(cv2.dnn.blobFromImage(self.img, size=(300, 300), swapRB=True, crop=False))
        cv_out = self.net.forward()

        boxes = []
        scores = []

        for detection in cv_out[0, 0, :, :]:
            score = float(detection[2])
            # TODO: Need to make this the confidence threshold...
            if score > 0.6:
                left = detection[3] * cols
                top = detection[4] * rows
                right = detection[5] * cols
                bottom = detection[6] * rows
                boxes.append((left, top, right, bottom))
                scores.append(score)
            else:
                # Scores are in descending order...
                break

        # Return arrays so callers (e.g. get_best_hand) can use boolean indexing.
        return np.array(boxes), np.array(scores)

    def get_best_hand(self, boxes, scores, conf_thresh, nms_thresh):
        """
        Gets the best hand bounding box by inspecting confidence scores and
        overlapping boxes, as well as the overall size of each box, to
        determine which hand (if multiple are present) should be recognised.
        """
        print(scores)
        boxes = boxes[scores > conf_thresh]
        scores = scores[scores > conf_thresh]
        # Use NMS to get rid of heavily overlapping boxes.
        # This wasn't used in the tensorflow example that was found, however it
        # is probably a good idea to use it just in case.
        print(boxes.shape)
        if boxes.shape[0] == 0:
            print("No good boxes found")
            return None
        elif boxes.shape[0] == 1:
            print("Only one good box!")
            box = boxes[0]
            box[0] = box[0] * self.img.shape[0]
            box[1] = box[1] * self.img.shape[1]
            box[2] = box[2] * self.img.shape[0]
            box[3] = box[3] * self.img.shape[1]
            return box.astype(int)
        else:
            # Column indexing must be boxes[:, i]; boxes[:][i] would select
            # row i instead.
            boxes[:, 2] = ((boxes[:, 2] - boxes[:, 0]) * self.img.shape[0]).astype(int)
            boxes[:, 3] = ((boxes[:, 3] - boxes[:, 1]) * self.img.shape[1]).astype(int)
            boxes[:, 0] = (boxes[:, 0] * self.img.shape[0]).astype(int)
            boxes[:, 1] = (boxes[:, 1] * self.img.shape[1]).astype(int)

            # Can't seem to get this to work...
            # indices = cv2.dnn.NMSBoxes(boxes, scores, conf_thresh, nms_thresh)

            print("Num boxes: %s" % boxes.shape[0])
            # Finally, calculate the area of each box to determine which hand
            # is clearest (biggest in the image).
            best_box = boxes[0]
            for box in boxes:
                if box[2] * box[3] > best_box[2] * best_box[3]:
                    best_box = box
            # Return the largest box (the original returned the last box seen,
            # regardless of size).
            return best_box

    def get_gesture(self):
        """
        Calculates the actual gesture, returning the number of fingers
        seen in the image.
        """
        print('Getting gesture')
        if self.img is None:
            print('There is no image')
            return -1
        # First cut out the frame using the neural network.
        # self.load_inference_graph()
        # print("loaded inference graph")
        # detections, scores = self.detect_hand_tensorflow(self.graph, self.sess)

        print('Loading openCV net')
        self.load_cv_net('/Users/piv/Documents/Projects/car/GestureRecognition/frozen_inference_graph.pb',
                         '/Users/piv/Documents/Projects/car/GestureRecognition/graph.pbtxt')

        detections, scores = self.detect_hand_opencv()

        # print("Getting best hand")
        # best_hand = self.get_best_hand(detections, scores, 0.7, 0.5)
        # if best_hand is not None:
        #     self.img = self.img[best_hand[0] - 30:best_hand[2] + 30, best_hand[1] - 30:best_hand[3] + 30]

        if len(detections) > 0:
            print("Cutting out the hand!")
            # detections holds (left, top, right, bottom) boxes; rows index the
            # vertical axis, so slice with top/bottom first and clamp at zero.
            box = detections[0]
            self.img_cut = self.img[max(0, int(box[1]) - 30):int(box[3]) + 30,
                                    max(0, int(box[0]) - 30):int(box[2]) + 30]
        else:
            self.img_cut = self.img

        print('Attempting to use pure hand recognition')
        self.img_cut = cv2.cvtColor(self.img_cut, cv2.COLOR_BGR2HSV)

        # Need to shift red pixels so they can be 0-20 rather than 250-~20
        self.img_cut = self.__shift_pixels(self.img_cut, 30)

        self.img_cut = self.__denoise(self.img_cut)
        self.__segment_image()

        print('calculating circle')
        # Could calculate multiple circles to get a probability for each
        # gesture (i.e. count how often each gesture is recognised and take
        # the percentage as the probability).
        radius, centre = self.__calc_circle(self.img_cut)
        print('Got circle')

        # Now go around the circle to count the number of times the mask flips
        # 0->255 or vice-versa.
        # First just do it the naive way with loops.
        # Equation of the circle:
        # y = sqrt(r2 - (x-c)2) + c
        prev_x = centre[0] - radius
        prev_y = [self.__calc_pos_y(centre[0] - radius, radius, centre),
                  self.__calc_pos_y(centre[0] - radius, radius, centre)]
        num_change = 0

        # Make sure x is also within bounds.
        x_start = centre[0] - radius + 1
        if x_start < 0:
            x_start = 0

        x_end = centre[0] + radius
        if x_end >= self.img_cut.shape[1]:
            x_end = self.img_cut.shape[1] - 1

        for x in range(x_start, x_end):
            # Need to check the circle is inside the bounds.
            ypos = self.__calc_pos_y(x, radius, centre)
            # y above the centre (ypos) and y mirrored below it.
            y = [ypos, centre[1] - (ypos - centre[1])]

            if y[0] < 0:
                y[0] = 0
            if y[0] >= self.img_cut.shape[0]:
                y[0] = self.img_cut.shape[0] - 1
            if y[1] < 0:
                y[1] = 0
            if y[1] >= self.img_cut.shape[0]:
                y[1] = self.img_cut.shape[0] - 1
            if self.img_cut[y[0], x] != self.img_cut[prev_y[0], prev_x]:
                num_change += 1
            if self.img_cut[y[1], x] != self.img_cut[prev_y[1], prev_x] and y[0] != y[1]:
                num_change += 1
            prev_x = x
            prev_y = y

        print('Finished calculating, returning')
        print(num_change)
        # Callers use the result as a scalar vote, so return just the count
        # (the original also returned self.img in a tuple).
        return int(num_change / 2 - 1)

    def get_gesture_multiple_radii(self):
        pass

    def calc_hand_batch(self, batch):
        pass
0
car/src/GestureRecognition/__init__.py
Normal file
BIN
car/src/GestureRecognition/frozen_inference_graph.pb
Normal file
Binary file not shown.
3146
car/src/GestureRecognition/graph.pbtxt
Normal file
File diff suppressed because it is too large
15
car/src/GestureRecognition/handrecogniser.py
Normal file
@@ -0,0 +1,15 @@
class HandRecogniser:
    """
    Interface for recognising simple hand gestures from an image (or frame of a video).
    """
    def load_image(self, image_path=""):
        """
        Loads the given image; can be lazy loading.
        """
        pass

    def get_gesture(self):
        """
        Gets the gesture recognised in the image.
        """
        pass
73
car/src/GestureRecognition/kaleidoscope.py
Normal file
@@ -0,0 +1,73 @@
import numpy as np
import cv2
# make_trapezoid below still uses PIL operations left over from an earlier
# version of this experiment.
from PIL import Image


def make_triangle(start_img):
    h, w, d = start_img.shape

    # Crop to a square.
    inset = int((max(w, h) - min(w, h)) / 2)
    # sqrimg = start_img.crop(inset, inset, h-inset, w-inset)
    insetW = inset if w > h else 0
    insetH = inset if h > w else 0
    sqrimg = start_img[insetH:h-insetH, insetW:w-insetW]

    # Solve the equilateral triangle.
    w, h, d = sqrimg.shape
    print((w, h))

    mask = np.zeros((w, h, d))

    t_height = w/2 * np.tan(60)
    pts = np.array([[0, w], [h/2, t_height], [h, w]], np.int32)
    pts = pts.reshape((-1, 1, 2))
    mask = cv2.fillPoly(mask, [pts], (255, 0, 0))

    # With the mask, get the triangle from the original image.
    sqrimg[:,:,0] = np.where(mask[:,:,0] == 255, sqrimg[:,:,0], 0)
    sqrimg[:,:,1] = np.where(mask[:,:,0] == 255, sqrimg[:,:,1], 0)
    sqrimg[:,:,2] = np.where(mask[:,:,0] == 255, sqrimg[:,:,2], 0)
    return sqrimg


def rotate(im, rotation):
    M = cv2.getRotationMatrix2D((im.shape[1]/2, im.shape[0]/2), rotation, 1)
    im[:,:,0] = cv2.warpAffine(im[:,:,0], M, (im.shape[1], im.shape[0]))
    im[:,:,1] = cv2.warpAffine(im[:,:,1], M, (im.shape[1], im.shape[0]))
    im[:,:,2] = cv2.warpAffine(im[:,:,2], M, (im.shape[1], im.shape[0]))
    return im


def make_kaleidoscope(img):
    # Incomplete in this version; starkaleid.make_kaleidoscope is the working one.
    triangle = make_triangle(img)
    return triangle


def make_trapezoid(triangle, save=False):
    # NOTE: this function still mixes PIL calls (size, paste, crop, save) with
    # the NumPy/OpenCV arrays used above, and `filename` is not defined in this
    # module; it is kept as-is from the earlier experiment.
    w, h = triangle.size
    can_w, can_h = w*3, h
    output = Image.new('RGBA', (can_w, can_h), color=255)

    def mirror_paste(last_img, coords):
        mirror = rotate(cv2.flip(last_img, 1), 60)
        output.paste(mirror, (coords), mirror)
        return mirror, coords

    # Paste in the bottom-left corner.
    output.paste(triangle, (0, can_h-h), triangle)

    last_img, coords = mirror_paste(triangle, (int(w/4.4), -int(h/2.125)))
    last_img, coords = mirror_paste(rotate(last_img, 120), (int(can_w/7.3), -228))

    output = output.crop((0, 15, w*2-22, h))
    if save:
        path = 'output/trapezoid_{}'.format(filename.split('/')[1])
        output.save(path)
        return output, path
    return output


if __name__ == "__main__":
    img = cv2.imread("/Users/piv/Documents/Projects/car/GestureRecognition/IMG_0818.png")
    triangle = make_triangle(img)
    triangle = cv2.resize(triangle, None, fx=0.3, fy=0.3, interpolation=cv2.INTER_AREA)
    triangle = rotate(triangle, 180)
    cv2.imshow("", triangle)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
28
car/src/GestureRecognition/keras_ex.py
Normal file
@@ -0,0 +1,28 @@
import time
import os

import numpy as np

os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"

import keras
import keras.applications as kapp
from keras.datasets import cifar10

(x_train, y_train_cats), (x_test, y_test_cats) = cifar10.load_data()
batch_size = 8
x_train = x_train[:batch_size]
x_train = np.repeat(np.repeat(x_train, 7, axis=1), 7, axis=2)
model = kapp.VGG19()
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['accuracy'])

print("Running initial batch (compiling tile program)")
y = model.predict(x=x_train, batch_size=batch_size)

# Now start the clock and run 10 batches
print("Timing inference...")
start = time.time()
for i in range(10):
    y = model.predict(x=x_train, batch_size=batch_size)
print("Ran in {} seconds".format(time.time() - start))
23
car/src/GestureRecognition/opencvtensorflowex.py
Normal file
@@ -0,0 +1,23 @@
import cv2 as cv

cvNet = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')

img = cv.imread('IMG_0825.jpg')
img = cv.resize(img, None, fx=0.1, fy=0.1, interpolation=cv.INTER_AREA)
rows = img.shape[0]
cols = img.shape[1]
print(str(rows) + " " + str(cols))
cvNet.setInput(cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False))
cvOut = cvNet.forward()

for detection in cvOut[0,0,:,:]:
    score = float(detection[2])
    if score > 0.6:
        left = detection[3] * cols
        top = detection[4] * rows
        right = detection[5] * cols
        bottom = detection[6] * rows
        cv.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (23, 230, 210), thickness=2)

cv.imshow('img', img)
cv.waitKey()
58
car/src/GestureRecognition/starkaleid.py
Normal file
@@ -0,0 +1,58 @@
import numpy as np
import cv2


def make_triangle(img, num_triangles):
    print(img.shape)
    y, x = (img.shape[0]//2, img.shape[1]//2)
    angles = 2 * np.pi/num_triangles
    print(angles/2)
    # img.shape is (rows, cols, channels), so unpack the height first; the
    # wedge geometry below uses the vertical extent.
    h, w, d = img.shape
    print(np.tan(angles/2))
    z = int(np.tan(angles/2) * (h/2))
    print(z)
    print(h)
    u = (x + z, y + h//2)
    v = (x - z, y + h//2)
    mask = np.zeros((h, w, d))

    pts = np.array([v, (x, y), u], np.int32)
    pts = pts.reshape((-1, 1, 2))
    mask = cv2.fillPoly(mask, [pts], (255, 0, 0))

    # With the mask, get the triangle from the original image.
    img[:,:,0] = np.where(mask[:,:,0] == 255, img[:,:,0], 0)
    img[:,:,1] = np.where(mask[:,:,0] == 255, img[:,:,1], 0)
    img[:,:,2] = np.where(mask[:,:,0] == 255, img[:,:,2], 0)
    return img


def rotate(im, rotation):
    M = cv2.getRotationMatrix2D((im.shape[1]/2, im.shape[0]/2), rotation, 1)
    im[:,:,0] = cv2.warpAffine(im[:,:,0], M, (im.shape[1], im.shape[0]))
    im[:,:,1] = cv2.warpAffine(im[:,:,1], M, (im.shape[1], im.shape[0]))
    im[:,:,2] = cv2.warpAffine(im[:,:,2], M, (im.shape[1], im.shape[0]))
    return im


def _stitch(img, to_stitch):
    # Fill only pixels that are empty in img and filled in to_stitch.
    img[:,:,0] = np.where((img[:,:,0] == 0) & (to_stitch[:,:,0] != 0), to_stitch[:,:,0], img[:,:,0])
    img[:,:,1] = np.where((img[:,:,1] == 0) & (to_stitch[:,:,1] != 0), to_stitch[:,:,1], img[:,:,1])
    img[:,:,2] = np.where((img[:,:,2] == 0) & (to_stitch[:,:,2] != 0), to_stitch[:,:,2], img[:,:,2])


def make_kaleidoscope(img, num):
    triangle = make_triangle(img, num)
    iters = num
    while iters > 0:
        new_triangle = np.copy(triangle)
        # Mirror every other wedge so adjacent copies reflect into each other.
        new_triangle = cv2.flip(new_triangle, 1) if iters % 2 != 0 else new_triangle
        rotate(new_triangle, 360/num * iters)
        _stitch(triangle, new_triangle)
        iters -= 1
    return triangle


if __name__ == "__main__":
    img = cv2.imread("/Users/piv/Documents/Projects/car/GestureRecognition/IMG_0818.png")
    img = cv2.resize(img, None, fx=0.3, fy=0.3, interpolation=cv2.INTER_AREA)
    num = 12
    kaleid = make_kaleidoscope(img, num)
    cv2.imshow("", kaleid)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
0
car/src/Messaging/__init__.py
Normal file
64
car/src/Messaging/message_factory.py
Normal file
@@ -0,0 +1,64 @@
import zmq


class ZmqPubSubStreamer:
    '''
    Not thread-safe. Always get this inside the thread/process where you intend
    to use it.
    '''

    def __init__(self, port):
        self._socket = zmq.Context.instance().socket(zmq.PUB)
        print('Starting socket with address: ' + 'tcp://*:' + str(port))
        self._socket.bind("tcp://*:" + str(port))

    def send_message(self, message):
        '''
        Args
        ----
        message: A message type that has the serialise() method.
        '''
        self.send_message_topic("", message)

    def send_message_topic(self, topic, message):
        # Topic frames must be bytes, so encode the string first.
        self._socket.send_multipart([topic.encode(), message.serialise()])


class BluetoothStreamer:
    def __init__(self):
        pass

    def send_message(self, message_bytes):
        pass


class TestStreamer:
    def __init__(self):
        self._listeners = []

    def send_message(self, message_bytes):
        print('Got a message')

    def send_message_topic(self, topic, message):
        print('Got a message with topic: ' + str(topic))
        self._fire_message_received(message)

    def add_message_listener(self, listener):
        self._listeners.append(listener)

    def _fire_message_received(self, message):
        for listener in self._listeners:
            listener(message)


def getZmqPubSubStreamer(port):
    '''
    Not thread-safe. Always get this inside the thread/process where you intend
    to use it.
    '''
    return ZmqPubSubStreamer(port)


def getTestingStreamer():
    return TestStreamer()

# TODO: Create a general get method that will get the streamer based on an
# environment variable that is set.
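A minimal sketch of the subscribing side, for reference; the port is a placeholder and must match whatever the PUB side was bound with:

import zmq

context = zmq.Context.instance()
sub = context.socket(zmq.SUB)
sub.connect('tcp://localhost:5050')    # assumed port, match the PUB side
sub.subscribe(b'')                     # empty prefix subscribes to all topics

topic, payload = sub.recv_multipart()  # blocks until a message arrives
print(topic, len(payload))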
34
car/src/Messaging/messages.py
Normal file
@@ -0,0 +1,34 @@
import umsgpack


class Message():
    def __init__(self, message=None):
        self.message = message

    def serialise(self):
        raise NotImplementedError

    def deserialise(self, message):
        raise NotImplementedError


class PackMessage(Message):

    def serialise(self):
        return umsgpack.packb(self.message)

    def deserialise(self, message):
        # Unpack the bytes passed in, not this instance's own payload.
        return PackMessage(umsgpack.unpackb(message))


class ProtoMessage(Message):

    def __init__(self, proto_type=None, message=None):
        super().__init__(message)
        self._type = proto_type

    def serialise(self):
        return self.message.SerializeToString()

    def deserialise(self, message):
        # FromString is the protobuf classmethod that builds a new message
        # from raw bytes (ParseFromString only works on an existing instance).
        return ProtoMessage(self._type, self._type.FromString(message))
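A quick round-trip sketch of the msgpack-backed wrapper; the dictionary contents are arbitrary example values:

from Messaging.messages import PackMessage

original = PackMessage({'throttle': 0.5, 'steering': -0.1})
wire_bytes = original.serialise()             # bytes ready for a socket
decoded = PackMessage().deserialise(wire_bytes)
assert decoded.message == original.message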
64
car/src/Messaging/mqttsession.py
Normal file
@@ -0,0 +1,64 @@
import paho.mqtt.client as mqtt

"""
Wrapper module for the paho mqtt library, providing a singleton instance of the client to be used.
Also adds some convenience functions such as having multiple connected callbacks,
and managing whether the client is still connected.
"""


client = mqtt.Client()
host = None
connected = False

connect_callbacks = []
disconnect_callbacks = []


def on_connect(client, userdata, flags, rc):
    print("Connected with result code " + str(rc))
    if rc == 0:
        global connected
        connected = True

    for callback in connect_callbacks:
        callback()

    client.subscribe('hello/test', qos=1)


# Arguably not needed, just want to make the client static, but here anyway.
def connect():
    global client
    if client is None or host is None:
        print("Error: Client and/or host are not initialised.")
    else:
        client.connect(host, port=1883, keepalive=60, bind_address="")
        client.loop_start()


def add_connect_callback(callback):
    connect_callbacks.append(callback)


def add_disconnect_callback(callback):
    disconnect_callbacks.append(callback)


def disconnect():
    global client
    if client is not None:
        client.loop_stop()
        client.disconnect()
    else:
        print("Error: Client is not initialised.")


def on_disconnect(client, userdata, rc):
    if rc != 0:
        print("Unexpected disconnection.")

    global connected
    connected = False


# Wire the module-level handlers into the paho client.
client.on_connect = on_connect
client.on_disconnect = on_disconnect


def Client():
    global client
    if client is None:
        client = mqtt.Client()

    return client
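A short usage sketch of the session module, assuming car/src is on sys.path and a broker is reachable (hostname is a placeholder):

import Messaging.mqttsession as session

session.host = 'localhost'                      # assumed broker address
session.add_connect_callback(lambda: print('connected to broker'))
session.connect()                               # starts the paho network loop
# ... do work ...
session.disconnect()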
0
car/src/__init__.py
Normal file
33
car/src/control/PythonRemoteController.py
Normal file
@@ -0,0 +1,33 @@
print("Connecting to pi")

import grpc
from concurrent import futures
import motorService_pb2_grpc
from motorService_pb2 import SteeringRequest, ThrottleRequest
import time

throttle = 0.1
timer = None

class ThrottleIterator:
    '''
    Iterator that yields the current module-level throttle once per second.
    Iteration stops when the throttle leaves the valid range [-1, 1].
    '''
    def __iter__(self):
        return self

    def __next__(self):
        if throttle > 1 or throttle < -1:
            raise StopIteration
        time.sleep(1)
        return ThrottleRequest(throttle=throttle)


channel = grpc.insecure_channel('10.0.0.53:50051')
stub = motorService_pb2_grpc.CarControlStub(channel)

response = stub.SetThrottle(ThrottleIterator())

while True:
    # Normalise the -100..100 user input to the -1..1 servo range.
    throttle = int(input('Please enter a value for the throttle between -100 and 100: ')) / 100
0
car/src/control/__init__.py
Normal file
0
car/src/control/gpio/__init__.py
Normal file
42
car/src/control/gpio/mockvehicle.py
Normal file
@@ -0,0 +1,42 @@
# A dummy vehicle class to use when no GPIO hardware is available
# (e.g. when testing off the Pi).
class MockVehicle:
    def __init__(self, motor_pin=19, servo_pin=18):
        self.motor_pin = motor_pin
        self.steering_pin = servo_pin

    @property
    def throttle(self):
        return self._throttle

    @throttle.setter
    def throttle(self, value):
        self._throttle = value

    @property
    def steering(self):
        return self._steering

    @steering.setter
    def steering(self, value):
        self._steering = value

    @property
    def motor_pin(self):
        return self._motor_pin

    @motor_pin.setter
    def motor_pin(self, value):
        self._motor_pin = value

    @property
    def steering_pin(self):
        return self._steering_pin

    @steering_pin.setter
    def steering_pin(self, value):
        self._steering_pin = value

    def stop(self):
        self.throttle = 0
83
car/src/control/gpio/vehicle.py
Normal file
@@ -0,0 +1,83 @@
from gpiozero import Servo, Device
from gpiozero.pins.pigpio import PiGPIOFactory
import subprocess


def _safely_set_servo_value(servo, value):
    try:
        if value < -1 or value > 1:
            print("Not setting servo, invalid value given.")
            return False
        servo.value = value
    except TypeError:
        print("Servo value should be a number, preferably a float.")
        return False
    return True


def _is_pin_valid(pin):
    if isinstance(pin, int):
        if pin < 2 or pin > 21:
            print("Invalid GPIO pin")
            return False
        return True
    else:
        print("Value must be an int.")
        return False


# TODO: Allow a vector to be set to change the throttle/steering, for vehicles that don't use
# two servos for controls (e.g. drone, dog)
class Vehicle:
    def __init__(self, motor_pin=19, servo_pin=18):
        subprocess.call(['sudo', 'pigpiod'])
        Device.pin_factory = PiGPIOFactory()
        print('Using pin factory:')
        print(Device.pin_factory)
        self.motor_pin = motor_pin
        self.steering_pin = servo_pin
        self.initialise_motor()

    def initialise_motor(self):
        self._motor_servo = Servo(
            self._motor_pin, pin_factory=Device.pin_factory)
        self._steering_servo = Servo(self._steering_pin, pin_factory=Device.pin_factory)

    @property
    def throttle(self):
        return self._motor_servo.value

    @throttle.setter
    def throttle(self, value):
        _safely_set_servo_value(self._motor_servo, value)

    @property
    def steering(self):
        # Read from the steering servo, not the motor servo.
        return self._steering_servo.value

    @steering.setter
    def steering(self, value):
        _safely_set_servo_value(self._steering_servo, value)

    @property
    def motor_pin(self):
        return self._motor_pin

    @motor_pin.setter
    def motor_pin(self, value):
        # TODO: Reinitialise the servo when the pin changes, or discard this method
        # (probably don't want to allow pin changes whilst the device is in use anyway)
        self._motor_pin = value if _is_pin_valid(value) else self._motor_pin

    @property
    def steering_pin(self):
        return self._steering_pin

    @steering_pin.setter
    def steering_pin(self, value):
        self._steering_pin = value if _is_pin_valid(value) else self._steering_pin

    def stop(self):
        self.throttle = 0
        self.steering = 0

    def change_with_vector(self, vector):
        pass
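A minimal usage sketch, assuming a Pi with pigpiod installed; the pin numbers are the defaults from the class:

from control.gpio.vehicle import Vehicle

vehicle = Vehicle(motor_pin=19, servo_pin=18)
vehicle.throttle = 0.25    # quarter throttle forward
vehicle.steering = -0.5    # turn left
vehicle.stop()             # zero both servos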
40
car/src/control/motor_servicer.py
Normal file
@@ -0,0 +1,40 @@
from threading import Timer, Thread
from concurrent import futures
import time

import control.motorService_pb2 as motorService_pb2
import control.motorService_pb2_grpc as motorService_pb2_grpc

class MotorServicer(motorService_pb2_grpc.CarControlServicer):
    def __init__(self, vehicle):
        self.vehicle = vehicle
        self._timer = None

    def SetThrottle(self, request, context):
        # gRPC streams currently don't work between python and android.
        # If we don't get a request every 3 seconds, stop the car.
        print('Setting throttle to: ' + str(request.throttle))
        self.set_timeout(3)
        self.vehicle.throttle = request.throttle
        return motorService_pb2.ThrottleResponse(throttleSet=True)

    def SetSteering(self, request, context):
        print('Setting steering to: ' + str(request.steering))
        self.vehicle.steering = request.steering
        return motorService_pb2.SteeringResponse(steeringSet=True)

    def set_timeout(self, min_timeout):
        """Stops the old timer and restarts it to the specified time.

        min_timeout -- The minimum time that can be used for the timer.
        """
        if self._timer is not None:
            self._timer.cancel()
        self._timer = Timer(min_timeout, self.timeout_elapsed)
        self._timer.start()

    def timeout_elapsed(self):
        """The throttle watchdog elapsed without a new request; stop the car."""
        print("Throttle timeout elapsed, stopping vehicle")
        self.vehicle.stop()
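A sketch of the client side of this watchdog, mirroring the streaming call used in PythonRemoteController.py (host/port are placeholders):

import time
import grpc
import motorService_pb2_grpc
from motorService_pb2 import ThrottleRequest

def throttle_stream():
    # Yield a request at least every 3 seconds so the servicer's watchdog
    # timer keeps getting refreshed and the car keeps moving.
    while True:
        yield ThrottleRequest(throttle=0.2)
        time.sleep(1)

channel = grpc.insecure_channel('10.0.0.53:50051')
stub = motorService_pb2_grpc.CarControlStub(channel)
stub.SetThrottle(throttle_stream())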
63
car/src/controller.py
Executable file
@@ -0,0 +1,63 @@
#!/usr/bin/env python3

from threading import Timer, Thread
from concurrent import futures
import time

import grpc

import control.motorService_pb2_grpc as motorService_pb2_grpc
from control.gpio.vehicle import Vehicle
from control.motor_servicer import MotorServicer
from slam.slam_servicer import SlamServicer
import slam.SlamController_pb2_grpc as SlamController_pb2_grpc
import tracking.lidar_tracker_pb2_grpc as lidar_tracker_pb2_grpc
from tracking.lidar_servicer import LidarServicer


class CarServer():

    def __init__(self, vehicle):
        self.vehicle = vehicle

    def start_server(self):
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
        motorService_pb2_grpc.add_CarControlServicer_to_server(self.create_motor_servicer(), server)
        SlamController_pb2_grpc.add_SlamControlServicer_to_server(
            self.create_slam_servicer(), server)
        lidar_tracker_pb2_grpc.add_PersonTrackingServicer_to_server(
            self.create_lidar_servicer(), server)
        # Disable tls for local testing.
        # server.add_secure_port('[::]:50051', self.create_credentials())
        server.add_insecure_port('[::]:50051')
        server.start()
        while True:
            time.sleep(60*60)

    def create_motor_servicer(self):
        return MotorServicer(self.vehicle)

    def create_slam_servicer(self):
        return SlamServicer()

    def create_lidar_servicer(self):
        return LidarServicer()

    def create_credentials(self):
        # Relativise this stuff.
        pvtKeyPath = '/home/pi/tls/device.key'
        pvtCertPath = '/home/pi/tls/device.crt'

        pvtKeyBytes = open(pvtKeyPath, 'rb').read()
        pvtCertBytes = open(pvtCertPath, 'rb').read()

        return grpc.ssl_server_credentials([(pvtKeyBytes, pvtCertBytes)])


if __name__ == '__main__':
    vehicle = Vehicle()
    server = CarServer(vehicle)

    # Can't remember why I do this, is it even needed?
    service_thread = Thread(target=server.start_server)
    service_thread.start()
0
car/src/slam/__init__.py
Normal file
31
car/src/slam/slam_servicer.py
Normal file
@@ -0,0 +1,31 @@
import slam.SlamController_pb2_grpc as grpc
import slam.SlamController_pb2 as proto
import slam.slam_streamer as slam
from multiprocessing import Process


class SlamServicer(grpc.SlamControlServicer):
    slam_thread = None

    def __init__(self):
        print('Servicer initialised')
        self.slam = slam.SlamStreamer()

    def start_map_streaming(self, request, context):
        print('Received Map Start Streaming Request')
        if self.slam_thread is None:
            print('initialising slam_thread')
            # Don't bother creating and starting slam more than once.
            self.slam.port = request.port
            self.slam.map_pixels = request.map_size_pixels
            self.slam.map_meters = request.map_size_meters
            self.slam_thread = Process(target=self.slam.start)
            self.slam_thread.start()
        return proto.Empty()

    def stop_streaming(self, request, context):
        if self.slam_thread is not None:
            self.slam.stop_scanning()
            self.slam_thread.join()
            self.slam = None
            # Clear the process handle so streaming can be started again.
            self.slam_thread = None
        return proto.Empty()
122
car/src/slam/slam_streamer.py
Normal file
@@ -0,0 +1,122 @@
import zmq
from breezyslam.algorithms import RMHC_SLAM
from breezyslam.sensors import RPLidarA1 as LaserModel
from slam.SlamController_pb2 import SlamScan, SlamLocation
import messaging.message_factory as mf
import messaging.messages as messages
import tracking.devices.factory as lidar_fact


# Left here as was used in the example, configure as necessary.
# MAP_SIZE_PIXELS = 500
# MAP_SIZE_METERS = 10
# LIDAR_DEVICE = '/dev/ttyUSB0'

class SlamStreamer:
    can_scan = False

    def __init__(self, map_pixels=None, map_meters=None, port=None):
        self._map_pixels = map_pixels
        self._map_meters = map_meters
        self._port = port

    def start(self):
        '''
        Does scanning and constructs the slam map,
        and pushes to subscribers through a zmq pub socket.
        This is done on the main thread, so you'll need
        to run this method on a separate thread yourself.

        All constructor parameters must be set prior
        to calling this method, and changing those values after
        calling this method will have no effect.
        '''
        self.can_scan = True
        print('Starting to stream')
        self._mFactory = mf.getZmqPubSubStreamer(self._port)

        print('Started and bound zmq socket.')

        # Adapted from BreezySLAM rpslam example.
        # Connect to Lidar unit. For some reason it likes to be done twice, otherwise it errors out...
        lidar = lidar_fact.get_lidar()
        lidar = lidar_fact.get_lidar()

        print('Initialised lidar')

        # Create an RMHC SLAM object with a laser model and optional robot model
        slam = RMHC_SLAM(LaserModel(), self._map_pixels, self._map_meters)

        print('initialised slam')

        # Initialize empty map
        mapbytes = bytearray(self.map_pixels * self.map_pixels)

        print('Initialised byte []')

        # Create an iterator to collect scan data from the RPLidar
        iterator = lidar.iter_scans()

        print('Scanning')

        while self.can_scan:
            # Extract (quality, angle, distance) triples from current scan
            items = [item for item in next(iterator)]

            # Extract distances and angles from triples
            distances = [item[2] for item in items]
            angles = [item[1] for item in items]
            print('Updating map')
            # Update SLAM with current Lidar scan and scan angles
            slam.update(distances, scan_angles_degrees=angles)
            print('Map updated')
            slam.getmap(mapbytes)
            self._push_map(mapbytes, slam.getpos())

    def _push_map(self, mapbytes, location):
        '''
        Pushes a scan over zmq using protocol buffers.
        mapbytes should be the result of slam.getmap.
        location should be a tuple, the result of slam.getpos()
        '''
        protoScan = messages.ProtoMessage(message=SlamScan(map=bytes(mapbytes),
            location=SlamLocation(x=location[0], y=location[1], theta=location[2])))
        print('Sending map')
        self._mFactory.send_message_topic(
            'slam_map', protoScan)

    def stop_scanning(self):
        self.can_scan = False

    # Properties
    @property
    def map_pixels(self):
        return self._map_pixels

    @map_pixels.setter
    def map_pixels(self, value):
        self._map_pixels = value

    @property
    def map_meters(self):
        return self._map_meters

    @map_meters.setter
    def map_meters(self, value):
        self._map_meters = value

    @property
    def lidar_connection(self):
        return self._lidar_connection

    @lidar_connection.setter
    def lidar_connection(self, value):
        self._lidar_connection = value

    @property
    def port(self):
        return self._port

    @port.setter
    def port(self, value):
        self._port = value
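A sketch of a subscriber that consumes the published map; the port is whatever was passed to start_map_streaming, and SlamScan comes from the generated SlamController_pb2 module:

import zmq
from slam.SlamController_pb2 import SlamScan

sub = zmq.Context.instance().socket(zmq.SUB)
sub.connect('tcp://localhost:5050')    # assumed streaming port
sub.subscribe(b'slam_map')

topic, payload = sub.recv_multipart()
scan = SlamScan.FromString(payload)    # parse the serialised protobuf
print(scan.location.x, scan.location.y, scan.location.theta)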
28
car/src/slam/zmq_pair_testing/pair.py
Normal file
@@ -0,0 +1,28 @@
import zmq
from threading import Thread
import time

context = zmq.Context.instance()

def client(context):
    print('in thread')
    socket = context.socket(zmq.SUB)
    print('created socket')
    socket.connect('tcp://localhost:5050')
    socket.subscribe(b'slam_map')
    while True:
        print(socket.recv())

def server(context):
    print('in thread')
    socket = context.socket(zmq.PUB)
    print('created socket')
    socket.bind('tcp://*:5050')
    while True:
        socket.send_multipart([b'slam_map', b'Hi'])
        time.sleep(1)

# client_thread = Thread(target=client, args=[context])
server_thread = Thread(target=server, args=[context])
server_thread.start()
# client_thread.start()
0
car/src/tracking/__init__.py
Normal file
212
car/src/tracking/algorithms.py
Normal file
@@ -0,0 +1,212 @@
import math


class Group:

    def __init__(self, number, points=None):
        # Avoid a mutable default argument; each group gets its own list.
        self._points = points if points is not None else []
        self._number = number
        self._minX = None
        self._maxX = None
        self._minY = None
        self._maxY = None

    def add_point(self, point):
        self._points.append(point)
        self._update_min_max(point)

    def get_points(self):
        return self._points

    @property
    def number(self):
        return self._number

    @number.setter
    def number(self, number):
        self._number = number

    def _update_min_max(self, new_point):
        """
        Updates the min and max points for this group.
        These bounds are used when assigning groups to decide whether
        two scans refer to the same group.
        """
        converted_point = convert_lidar_to_cartesian(new_point)

        if self._minX is None or self._minX > converted_point[0]:
            self._minX = converted_point[0]

        if self._maxX is None or self._maxX < converted_point[0]:
            self._maxX = converted_point[0]

        if self._minY is None or self._minY > converted_point[1]:
            self._minY = converted_point[1]

        if self._maxY is None or self._maxY < converted_point[1]:
            self._maxY = converted_point[1]

    def get_minX(self):
        return self._minX

    def get_maxX(self):
        return self._maxX

    def get_minY(self):
        return self._minY

    def get_maxY(self):
        return self._maxY


def convert_lidar_to_cartesian(new_point):
    # Lidar angles are in degrees, so convert before taking sin/cos.
    x = new_point[2] * math.sin(math.radians(new_point[1]))
    y = new_point[2] * math.cos(math.radians(new_point[1]))
    return (x, y)


def convert_cartesian_to_lidar(x, y):
    """
    Converts a point on the grid (with car as the origin) to a lidar tuple (distance, angle)

    Parameters
    ----------
    x
        Horizontal component of point to convert.

    y
        Vertical component of point to convert.

    Returns
    -------
    converted
        A tuple (distance, angle) that represents the point. Angle is in degrees.
    """
    # Angle depends on x/y position.
    # if x is positive and y is positive, then angle = tan-1(y/x)
    # if x is positive and y is negative, then angle = 360 + tan-1(y/x)
    # if x is negative and y is positive, then angle = 180 + tan-1(y/x)
    # if x is negative and y is negative, then angle = 180 + tan-1(y/x)
    return (math.sqrt(x ** 2 + y ** 2), math.degrees(math.atan(y / x)) + (180 if x < 0 else 360 if y < 0 else 0))


def calc_groups(scan):
    """
    Calculates groups of points from a lidar scan. The scan should
    already be sorted.

    Parameters
    ----------
    scan: Iterable
        The lidar scan data to get groups of.
        Should be of format: (quality, angle, distance)

    Returns
    -------
    list
        List of groups that were found.
    """
    prevPoint = None
    currentGroup = None
    allGroups = []
    currentGroupNumber = 0

    # assume the list is already sorted.
    for point in scan:
        if prevPoint is None:
            prevPoint = point
            continue

        # Distances are in mm.
        # Within 1cm makes a group. Will need to play around with this.
        if (point[2] - prevPoint[2]) ** 2 < 10 ** 2:
            if currentGroup is None:
                currentGroup = Group(currentGroupNumber)
                allGroups.append(currentGroup)
            currentGroup.add_point(point)
        else:
            if currentGroup is not None:
                currentGroupNumber += 1
                currentGroup = None

        prevPoint = point

    return allGroups


def find_centre(group):
    """
    Gets a tuple (x,y) of the centre of the group.

    Parameters
    ----------
    group: Group
        A group of points to find the centre of.

    Returns
    -------
    tuple (x,y)
        The centre in the form of a tuple (x,y)
    """
    return ((group.get_maxX() + group.get_minX()) / 2, (group.get_maxY() + group.get_minY()) / 2)


def assign_groups(prev_groups, new_groups):
    """
    Assigns group numbers to a new scan based on the groups of an old scan.
    """
    for group in prev_groups:
        old_centre = find_centre(group)
        for new_group in new_groups:
            new_centre = find_centre(new_group)
            # They are considered the same if the new group and old group centres are within 5cm.
            if ((new_centre[0] - old_centre[0]) ** 2 + (new_centre[1] - old_centre[1]) ** 2) < 50 ** 2:
                new_group.number = group.number

    return new_groups


def updateCarVelocity(oldGroup, newGroup):
    """
    Return a tuple (DistanceChange, AngleChange) indicating how the tracked groups have changed, which can
    be used to then update the steering/throttle of the car (or other vehicle that
    may be used)

    Parameters
    ----------
    oldGroup: Group
        The positioning of points for the group in the last scan.

    newGroup: Group
        The positioning of points for the group in the latest scan.

    Returns
    -------
    tuple (DistanceChange, AngleChange)
        A tuple containing how the groups' centres changed in the form (distance, angle)
    """
    old_polar = convert_cartesian_to_lidar(*find_centre(oldGroup))
    new_centre = convert_cartesian_to_lidar(*find_centre(newGroup))
    return (new_centre[0] - old_polar[0], new_centre[1] - old_polar[1])


def dualServoChange(newCentre, changeTuple):
    """
    Gets a tuple (throttleChange, steeringChange) indicating the change that should be applied to the current
    throttle/steering of an rc car that uses dual servos.

    Parameters
    ----------
    newCentre
        Tuple (distance, angle) of the new centre of the tracked group.

    changeTuple
        Tuple (distanceChange, angleChange) from the old centre to the new centre.

    Returns
    -------
    tuple
        Tuple of (throttleChange, steeringChange) to apply to the 2 servos.
    """
    return ((changeTuple[0] / 3) - (newCentre[0] / 4) + 1, 0)
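A small worked example of the grouping pipeline; the points are made up, and the quality values are unused by the algorithm:

from tracking.algorithms import calc_groups, find_centre

# Two clusters of (quality, angle, distance) readings, roughly 1m and 2m away.
scan = [(15, 10, 1000), (15, 11, 1004), (15, 12, 1007),
        (15, 90, 2000), (15, 91, 2003)]

groups = calc_groups(scan)
for g in groups:
    # Prints each group's number, its points, and its bounding-box centre.
    print(g.number, g.get_points(), find_centre(g))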
BIN
car/src/tracking/all_scans.txt
Normal file
Binary file not shown.
43
car/src/tracking/animate.py
Executable file
@@ -0,0 +1,43 @@
#!/usr/bin/env python3
'''Animates distances and measurement quality'''
from tracking.devices.mock_lidar import MockLidar
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import tracking.lidar_loader as loader


PORT_NAME = '/dev/ttyUSB0'
DMAX = 4000
IMIN = 0
IMAX = 50


def update_line(num, iterator, line):
    scan = next(iterator)
    offsets = np.array([(np.radians(meas[1]), meas[2]) for meas in scan])
    line.set_offsets(offsets)
    intens = np.array([meas[0] for meas in scan])
    line.set_array(intens)
    return line,


def run():
    lidar = MockLidar(loader.load_scans_bytes_file("tracking/out.pickle"))
    fig = plt.figure()
    ax = plt.subplot(111, projection='polar')
    line = ax.scatter([0, 0], [0, 0], s=5, c=[IMIN, IMAX],
                      cmap=plt.cm.Greys_r, lw=0)
    ax.set_rmax(DMAX)
    ax.grid(True)

    iterator = lidar.iter_scans()
    ani = animation.FuncAnimation(fig, update_line,
                                  fargs=(iterator, line), interval=50)
    plt.show()
    lidar.stop()
    lidar.disconnect()


if __name__ == '__main__':
    run()
55
car/src/tracking/animate_alg.py
Normal file
@@ -0,0 +1,55 @@
"""
Animates distances and angle of lidar
Uses model-free algorithms to track grouping of points (objects/groups)
"""
from tracking.devices.mock_lidar import MockLidar
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.animation as animation
import tracking.lidar_loader as loader
import tracking.algorithms as alg


PORT_NAME = '/dev/ttyUSB0'
DMAX = 4000
IMIN = 0
IMAX = 50

def update_line(num, iterator, line, prev_groups):
    scan = next(iterator)
    # Now update the groups, and then update the maps with different colours for different groups.
    if prev_groups.groups is None:
        prev_groups.groups = alg.calc_groups(scan)
    groups = alg.assign_groups(prev_groups.groups, alg.calc_groups(scan))
    # Store the latest groups on the mutable holder so numbering persists between frames.
    prev_groups.groups = groups
    offsets = np.array([(np.radians(meas[1]), meas[2]) for meas in scan])
    line.set_offsets(offsets)
    intens = np.array([meas[0] for meas in scan])
    line.set_array(intens)
    # Set the colour matrix: Just set the colours to 2 * np.pi * group number (for every group number)
    # line.set_color()
    return line,

class Bunch:
    def __init__(self, **kwds):
        self.__dict__.update(kwds)


def run():
    lidar = MockLidar(loader.load_scans_bytes_file("tracking/out.pickle"))
    fig = plt.figure()
    ax = plt.subplot(111, projection='polar')
    line = ax.scatter([0, 0], [0, 0], s=5, c=[IMIN, IMAX],
                      cmap=plt.cm.Greys_r, lw=0)
    ax.set_rmax(DMAX)
    ax.grid(True)
    prev_groups = Bunch(groups=None)
    iterator = lidar.iter_scans()
    ani = animation.FuncAnimation(fig, update_line,
                                  fargs=(iterator, line, prev_groups), interval=50)
    plt.show()
    lidar.stop()
    lidar.disconnect()


if __name__ == '__main__':
    run()
0
car/src/tracking/devices/__init__.py
Normal file
13
car/src/tracking/devices/factory.py
Normal file
@@ -0,0 +1,13 @@
from tracking.devices.mock_lidar import MockLidar
from rplidar import RPLidar
import tracking.lidar_loader as loader

connection = "TEST"
# connection = '/dev/ttyUSB0'

def get_lidar():
    # Need a way to configure this, maybe with environment variables
    if connection == 'TEST':
        return MockLidar(loader.load_scans_bytes_file("tracking/out.pickle"))
    else:
        return RPLidar(connection)
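One way the environment-variable configuration mentioned in the comment could look; LIDAR_CONNECTION is a hypothetical variable name, not something the repo defines:

import os

def get_lidar_from_env():
    # Defaults to the mock device so tests run without hardware attached.
    connection = os.environ.get('LIDAR_CONNECTION', 'TEST')
    if connection == 'TEST':
        return MockLidar(loader.load_scans_bytes_file("tracking/out.pickle"))
    return RPLidar(connection)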
43
car/src/tracking/devices/mock_lidar.py
Normal file
@@ -0,0 +1,43 @@
"""
This module contains a MockLidar class, for use in place of RPLidar.
Importantly, it implements iter_scans, so it can be substituted for RPLidar
in the lidar_cache for testing (or anywhere else the rplidar may be used)
"""

import tracking.lidar_loader as loader


class MockLidar:

    def __init__(self, scan_iter=None):
        """
        Create mock lidar with an iterator that can be used as fake (or reused) scan data.

        Examples
        --------
        lidar = MockLidar(scans)
        first_scan = next(lidar.iter_scans(min_len=100))

        Parameters
        ----------
        scan_iter: Iterable
            An iterator that will generate/provide the fake/old scan data.
        """
        self._iter = scan_iter

    def iter_scans(self, min_len=100):
        return iter(self._iter)

    def get_health(self):
        return "Mock Lidar has scans" if self._iter is not None else "Mock lidar won't work properly!"

    def get_info(self):
        return self.get_health()

    def stop(self):
        pass

    def disconnect(self):
        pass
84
car/src/tracking/lidar_cache.py
Normal file
@@ -0,0 +1,84 @@
from threading import Thread
from tracking import algorithms
import tracking.lidar_tracker_pb2 as tracker_pb
import zmq


class LidarCache():
    """
    A class that retrieves scans from the lidar,
    runs grouping algorithms between scans and
    keeps a copy of the group data.
    """

    def __init__(self, lidar, measurements=100):
        self.lidar = lidar
        self.measurements = measurements
        print('Info: ' + self.lidar.get_info())
        print('Health: ' + self.lidar.get_health())
        self.run = True
        self.tracking_group_number = -1
        self.currentGroups = None
        self._group_listeners = []

    def start_cache(self):
        self.thread = Thread(target=self.do_scanning)
        self.thread.start()

    def do_scanning(self):
        """Performs scans whilst the cache is running, passing the calculated
        group data to any registered listeners."""

        # Batch over scans, so we don't need to do our own batching to determine groups
        # TODO: Implement custom batching, as iter_scans can be unreliable
        for scan in self.lidar.iter_scans(min_len=self.measurements):
            print('Got %d measurements' % (len(scan)))
            if len(scan) < self.measurements:
                # Poor scan, likely since it was the first scan.
                continue

            if not self.run:
                break

            # Now process the groups.
            if self.currentGroups is not None:
                self.currentGroups = algorithms.assign_groups(
                    self.currentGroups, algorithms.calc_groups(scan))
            else:
                self.currentGroups = algorithms.calc_groups(scan)

            self.fireGroupsChanged()

    def fireGroupsChanged(self):
        # Send the updated groups to 0MQ socket.
        # Rename this to be a generic listener method, rather than an explicit 'send' (even though it can be treated as such already)
        pointScan = tracker_pb.PointScan()
        for group in self.currentGroups:
            for point in group.get_points():
                pointScan.points.append(tracker_pb.Point(
                    angle=point[1], distance=point[2], group_number=group.number))

        for listener in self._group_listeners:
            listener.onGroupsChanged(pointScan)

    def add_groups_changed_listener(self, listener):
        """
        Add a listener for a change in scans. This will provide the new group
        scans, which can then be sent off to a network listener for display, or to update the
        vehicle with a new velocity.

        Parameters
        ----------
        listener
            An object that implements the onGroupsChanged(message) method.
        """
        self._group_listeners.append(listener)

    def stop_scanning(self):
        self.run = False
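A minimal listener sketch for the cache; the class name is illustrative, and any object with an onGroupsChanged method works:

from tracking.lidar_cache import LidarCache
import tracking.devices.factory as lidar_factory

class PrintingListener:
    def onGroupsChanged(self, point_scan):
        # point_scan is a lidar_tracker_pb2.PointScan protobuf message.
        print('Scan with %d grouped points' % len(point_scan.points))

cache = LidarCache(lidar_factory.get_lidar(), measurements=100)
cache.add_groups_changed_listener(PrintingListener())
cache.start_cache()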
26
car/src/tracking/lidar_loader.py
Normal file
@@ -0,0 +1,26 @@
# -*- coding: utf-8 -*-
"""
This module is a utility to load and save lidar
scans to disk.
As such, it is useful for testing, to create real lidar
data that can be reused later, without needing to connect the lidar.
"""

from rplidar import RPLidar
import pickle


def get_scans(num_scans, device='/dev/ttyUSB0', measurements_per_scan=100):
    lidar = RPLidar(device)
    scans = lidar.iter_scans(measurements_per_scan)
    return [next(scans) for i in range(0, num_scans)]


def save_scans_bytes(scans, filename='out.pickle'):
    with open(filename, 'wb') as f:
        pickle.dump(scans, f)


def load_scans_bytes_file(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)
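A record-and-replay sketch using this module together with the mock lidar; the device path assumes the usual USB serial adapter:

import tracking.lidar_loader as loader
from tracking.devices.mock_lidar import MockLidar

# Record 20 real scans once, with the lidar plugged in...
scans = loader.get_scans(20, device='/dev/ttyUSB0')
loader.save_scans_bytes(scans, 'tracking/out.pickle')

# ...then replay them later with no hardware attached.
lidar = MockLidar(loader.load_scans_bytes_file('tracking/out.pickle'))
first = next(lidar.iter_scans())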
44
car/src/tracking/lidar_servicer.py
Normal file
@@ -0,0 +1,44 @@
import tracking.lidar_tracker_pb2 as lidar_tracker_pb2
from tracking.lidar_tracker_pb2_grpc import PersonTrackingServicer
from tracking.lidar_cache import LidarCache
from multiprocessing import Process
import messaging.message_factory as mf
import tracking.devices.factory as lidar_factory

from messaging import messages
import tracking.algorithms as alg

class LidarServicer(PersonTrackingServicer):

    def __init__(self, vehicle=None):
        # TODO: Put the rplidar creation in a factory or something, to make it possible to test this servicer.
        # Also, it would allow creating the service without the lidar being connected.
        self.cache = LidarCache(lidar_factory.get_lidar(), measurements=100)
        self.cache.add_groups_changed_listener(self)
        self._mFactory = None
        self._port = None
        self._vehicle = vehicle
        self._tracked_group = None

    def set_tracking_group(self, request, context):
        self._tracked_group = request.value

    def stop_tracking(self, request, context):
        self.cache.stop_scanning()

    def start_tracking(self, request, context):
        """Starts the lidar cache, streaming on the provided port."""
        self._port = request.value
        self.cache.start_cache()

    def onGroupsChanged(self, message):
        if self._mFactory is None:
            # Create the zmq socket in the thread that it will be used, just to be safe.
            self._mFactory = mf.getZmqPubSubStreamer(self._port)
        # Pass the protobuf message itself; ProtoMessage.serialise() calls
        # SerializeToString, so serialising here first would double-encode.
        self._mFactory.send_message_topic("lidar_map", messages.ProtoMessage(message=message))

        if self._tracked_group is not None and self._vehicle is not None:
            # Update vehicle to correctly follow the tracked group.
            # Leave for now, need to work out exactly how this will change.
            # alg.dualServoChange(alg.find_centre())
            pass
5
car/src/tracking/lidar_tester.py
Normal file
@@ -0,0 +1,5 @@
from tracking.lidar_cache import LidarCache
import Messaging.message_factory as mf
BIN
car/src/tracking/out.pickle
Normal file
Binary file not shown.
4
car/src/tracking/readme.txt
Normal file
@@ -0,0 +1,4 @@
To load the lidar dummy scans in all_scans.txt,
use python pickle:

import pickle

with open('path/to/all_scans.txt', 'rb') as fp:
    all_scans = pickle.load(fp)