pref: 注册时直接传入任务类

feat: 分任务设置检测计数值
This commit is contained in:
bmy
2024-05-29 21:23:05 +08:00
parent e3734c5ead
commit 49c0499f24
5 changed files with 539 additions and 117 deletions

View File

@@ -3,8 +3,8 @@ logger_filename = "log/file_{time}.log"
logger_format = "{time} {level} {message}" logger_format = "{time} {level} {message}"
[task] [task]
GetBlock_enable = true # 人员施救使能 GetBlock_enable = false # 人员施救使能
PutBlock_enable = true # 紧急转移使能 PutBlock_enable = false # 紧急转移使能
GetBBall_enable = true # 整装上阵使能 GetBBall_enable = true # 整装上阵使能
UpTower_enable = true # 通信抢修使能 UpTower_enable = true # 通信抢修使能
GetRBall_enable = true # 高空排险使能 GetRBall_enable = true # 高空排险使能
@@ -12,3 +12,15 @@ PutBBall_enable = true # 派发物资使能
PutHanoi_enable = true # 物资盘点使能 PutHanoi_enable = true # 物资盘点使能
MoveArea_enable = true # 应急避险使能 MoveArea_enable = true # 应急避险使能
KickAss_enable = true # 扫黑除暴使能 KickAss_enable = true # 扫黑除暴使能
[find_counts]
GetBlock_counts = 5 # 人员施救使能
PutBlock_counts = 5 # 紧急转移使能
GetBBall_counts = 5 # 整装上阵使能
UpTower_counts = 5 # 通信抢修使能
GetRBall_counts = 5 # 高空排险使能
PutBBall_counts = 5 # 派发物资使能
PutHanoi1_counts = 20 # 物资盘点使能
PutHanoi2_counts = 5 # 物资盘点使能
MoveArea_counts = 5 # 应急避险使能
KickAss_counts = 5 # 扫黑除暴使能

35
main.py
View File

@@ -4,6 +4,10 @@ import threading
from loguru import logger from loguru import logger
import subtask as sb import subtask as sb
import majtask as mj import majtask as mj
from by_cmd_py import by_cmd_py
import time
cmd_py_obj = by_cmd_py()
sb.import_obj(cmd_py_obj)
# 读取配置 # 读取配置
cfg_main = toml.load('cfg_main.toml') cfg_main = toml.load('cfg_main.toml')
@@ -12,17 +16,18 @@ cfg_main = toml.load('cfg_main.toml')
logger.add(cfg_main['debug']['logger_filename'], format=cfg_main['debug']['logger_format'], retention = 5, level="INFO") logger.add(cfg_main['debug']['logger_filename'], format=cfg_main['debug']['logger_format'], retention = 5, level="INFO")
# 向任务队列添加任务 # 向任务队列添加任务
# TODO 任务关闭相关联
task_queue = queue.Queue() task_queue = queue.Queue()
task_queue.put(sb.task(sb.get_block.exec, sb.get_block.find, cfg_main['task']['GetBlock_enable'])) task_queue.put(sb.task(sb.get_block, cfg_main['find_counts']['GetBlock_counts'], cfg_main['task']['GetBlock_enable']))
task_queue.put(sb.task(sb.put_block.exec, sb.put_block.find, cfg_main['task']['PutBlock_enable'])) task_queue.put(sb.task(sb.put_block, cfg_main['find_counts']['PutBlock_counts'], cfg_main['task']['PutBlock_enable']))
task_queue.put(sb.task(sb.get_bball.exec, sb.get_bball.find, cfg_main['task']['GetBBall_enable'])) task_queue.put(sb.task(sb.get_bball, cfg_main['find_counts']['GetBBall_counts'], cfg_main['task']['GetBBall_enable']))
task_queue.put(sb.task(sb.up_tower.exec, sb.up_tower.find, cfg_main['task']['UpTower_enable'])) task_queue.put(sb.task(sb.up_tower, cfg_main['find_counts']['UpTower_counts'], cfg_main['task']['UpTower_enable']))
task_queue.put(sb.task(sb.get_rball.exec, sb.get_rball.find, cfg_main['task']['GetRBall_enable'])) task_queue.put(sb.task(sb.get_rball, cfg_main['find_counts']['GetRBall_counts'], cfg_main['task']['GetRBall_enable']))
task_queue.put(sb.task(sb.put_bball.exec, sb.put_bball.find, cfg_main['task']['PutBBall_enable'])) task_queue.put(sb.task(sb.put_bball, cfg_main['find_counts']['PutBBall_counts'], cfg_main['task']['PutBBall_enable']))
task_queue.put(sb.task(sb.put_hanoi1.exec, sb.put_hanoi1.find, True)) # 无论是否进行任务,检测标识并转向都是必须进行的 task_queue.put(sb.task(sb.put_hanoi1, cfg_main['find_counts']['PutHanoi1_counts'], True)) # 无论是否进行任务,检测标识并转向都是必须进行的
task_queue.put(sb.task(sb.put_hanoi2.exec, sb.put_hanoi2.find, cfg_main['task']['PutHanoi_enable'])) task_queue.put(sb.task(sb.put_hanoi2, cfg_main['find_counts']['PutHanoi2_counts'], cfg_main['task']['PutHanoi_enable']))
task_queue.put(sb.task(sb.move_area.exec, sb.move_area.find, cfg_main['task']['MoveArea_enable'])) task_queue.put(sb.task(sb.move_area, cfg_main['find_counts']['MoveArea_counts'], cfg_main['task']['MoveArea_enable']))
task_queue.put(sb.task(sb.kick_ass.exec, sb.kick_ass.find, cfg_main['task']['KickAss_enable'])) task_queue.put(sb.task(sb.kick_ass, cfg_main['find_counts']['KickAss_counts'], cfg_main['task']['KickAss_enable']))
# 将任务队列传入调度模块中 # 将任务队列传入调度模块中
task_queuem_t = sb.task_queuem(task_queue) task_queuem_t = sb.task_queuem(task_queue)
@@ -35,12 +40,18 @@ def worker_thread():
# 启动工作线程 # 启动工作线程
worker = threading.Thread(target=worker_thread, daemon=True) worker = threading.Thread(target=worker_thread, daemon=True)
worker.start() worker.start()
if (cmd_py_obj.send_angle_camera(180) == -1):
cmd_py_obj.send_angle_camera(180)
time.sleep(2)
# cmd_py_obj.send_speed_x(5)
# cmd_py_obj.send_position_axis_z(10, 100)
# 创建主任务 # 创建主任务
main_task_t = mj.main_task(None) # TODO 初始化时传入 zmq socket 对象 main_task_t = mj.main_task(cmd_py_obj) # 初始化时传入 zmq socket 对象
# 主线程仅在子线程搜索 (SEARCHING) 和 空闲 (IDLE) 状态下进行操作 # 主线程仅在子线程搜索 (SEARCHING) 和 空闲 (IDLE) 状态下进行操作
while task_queuem_t.busy is True: # while task_queuem_t.busy is True:
while True:
if task_queuem_t.status is sb.task_queuem_status.EXECUTING: if task_queuem_t.status is sb.task_queuem_status.EXECUTING:
pass pass
else: else:

View File

@@ -1,5 +1,7 @@
from simple_pid import PID from simple_pid import PID
# import queue import zmq
import time
from loguru import logger
class PidWrap: class PidWrap:
def __init__(self, kp, ki, kd, setpoint=0, output_limits=1): def __init__(self, kp, ki, kd, setpoint=0, output_limits=1):
@@ -14,8 +16,10 @@ class PidWrap:
return self.pid_t(val_in) return self.pid_t(val_in)
class main_task(): class main_task():
def __init__(self,socket): def __init__(self,by_cmd):
self.lane_socket = socket self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect("tcp://localhost:6666")
# 赛道回归相关 # 赛道回归相关
self.x = 0 self.x = 0
@@ -24,7 +28,7 @@ class main_task():
self.lane_error = 0 self.lane_error = 0
# 车控制对象初始化 # 车控制对象初始化
# self.by_cmd = by_cmd_py() self.by_cmd = by_cmd
# 转向 pid # 转向 pid
self.pid1 = PidWrap(0.7, 0, 0,output_limits=40) self.pid1 = PidWrap(0.7, 0, 0,output_limits=40)
@@ -32,9 +36,8 @@ class main_task():
def parse_data(self,data): def parse_data(self,data):
if data.get('code') == 0: if data.get('code') == 0:
if data.get('type') == 'infer': self.x += data.get('data')[0]
self.x += data.get('data')[0][0] self.y += data.get('data')[1]
self.y += data.get('data')[0][1]
self.error_counts += 1 self.error_counts += 1
else: else:
@@ -52,6 +55,7 @@ class main_task():
def lane_task(self): def lane_task(self):
# TODO 巡航参数从配置文件中读取 # TODO 巡航参数从配置文件中读取
time.sleep(0.002)
if self.error_counts > 2: if self.error_counts > 2:
self.x = self.x / 3 self.x = self.x / 3
self.y = self.y / 3 self.y = self.y / 3
@@ -59,12 +63,27 @@ class main_task():
self.error_counts = 0 self.error_counts = 0
self.x = 0 self.x = 0
self.y = 0 self.y = 0
if self.lane_error > 30: error_abs = abs(self.lane_error)
if error_abs < 10:
self.pid1.set(0.7, 0, 0) self.pid1.set(0.7, 0, 0)
self.by_cmd.send_speed_x(12)
elif error_abs > 45:
self.by_cmd.send_speed_x(6)
self.pid1.set(1.8, 0, 0)
elif error_abs > 35:
self.by_cmd.send_speed_x(8)
self.pid1.set(1.5, 0, 0)
elif error_abs > 25:
self.by_cmd.send_speed_x(10)
self.pid1.set(1, 0, 0)
else: else:
self.pid1.set(0.5, 0, 0) self.pid1.set(0.8, 0, 0)
self.by_cmd.send_speed_x(11)
# TODO 待引入控制接口 # TODO 待引入控制接口
# self.by_cmd.send_speed_x(7)
pid_out = self.pid1.get(self.lane_error) pid_out = self.pid1.get(self.lane_error*0.65)
# self.by_cmd.send_speed_omega(pid_out) self.by_cmd.send_speed_omega(pid_out)
# self.lane_socket.send_string("infer") self.socket.send_string("")
resp = self.socket.recv_pyobj()
# logger.info(resp)
self.parse_data(resp)

View File

@@ -3,24 +3,52 @@ from loguru import logger
from utils import label_filter from utils import label_filter
from utils import tlabel from utils import tlabel
import toml import toml
import zmq
import time import time
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket.connect("tcp://localhost:6667")
logger.info("socket init")
by_cmd = None
filter = None
def import_obj(_by_cmd):
global by_cmd
global filter
by_cmd = _by_cmd
filter = label_filter(socket)
# 任务类 # 任务类
class task: class task:
def __init__(self, func_exec, func_find, enable=True): def __init__(self, task_template, find_counts=10, enable=True):
self.enable = enable self.enable = enable
self.func_exec = func_exec self.task_t = task_template()
self.func_find = func_find
self.counts = 0
self.find_counts = find_counts
def init(self):
self.task_t.init()
def find(self): def find(self):
# 检查该任执行标志 # 检查该任执行标志
# TODO 完善该接口,是否需要单独为每种 task 编写一个函数,还是设置一个通用的过滤器(从 detection 模块过滤结果) while True:
while self.func_find() is False: # if self.func_find():
pass if self.task_t.find():
self.counts += 1
if self.counts >= self.find_counts:
break
# while self.func_find() is False:
# pass
def exec(self): def exec(self):
# 根据标志位确定是否执行该任务 # 根据标志位确定是否执行该任务
if self.enable is True: if self.enable is True:
logger.debug(f"[Task ]# Executing task") logger.debug(f"[Task ]# Executing task")
self.func_exec() # self.func_exec()
self.task_t.exec()
logger.debug(f"[Task ]# Task completed") logger.debug(f"[Task ]# Task completed")
else: else:
logger.warning(f"[Task ]# Skip task") logger.warning(f"[Task ]# Skip task")
@@ -31,7 +59,7 @@ class task_queuem_status(Enum):
SEARCHING = 1 SEARCHING = 1
EXECUTING = 2 EXECUTING = 2
# 任务队列类 非EXECUTEING 时均执行 huigui注意互斥操作 # 任务队列类 非 EXECUTEING 时均执行 huigui注意互斥操作
class task_queuem(task): class task_queuem(task):
# task_now = task(None, False) # task_now = task(None, False)
def __init__(self, queue): def __init__(self, queue):
@@ -52,9 +80,10 @@ class task_queuem(task):
# 如果当前任务没有使能,则直接转入执行状态,由任务执行函数打印未执行信息 # 如果当前任务没有使能,则直接转入执行状态,由任务执行函数打印未执行信息
if self.task_now.enable is True: if self.task_now.enable is True:
self.status = task_queuem_status.SEARCHING self.status = task_queuem_status.SEARCHING
# 如果使能该任务则执行该任务的初始化动作
self.task_now.init()
else: else:
self.status = task_queuem_status.EXECUTING self.status = task_queuem_status.EXECUTING
logger.info(f"[TaskM]# ---------------------->>>>") logger.info(f"[TaskM]# ---------------------->>>>")
# 阻塞搜索任务标志位 # 阻塞搜索任务标志位
elif self.status is task_queuem_status.SEARCHING: elif self.status is task_queuem_status.SEARCHING:
@@ -75,143 +104,367 @@ class task_queuem(task):
# 人员施救 # 人员施救
class get_block(): class get_block():
def find(): def init(self):
logger.info("人员施救初始化")
def find(self):
# 目标检测红/蓝方块 # 目标检测红/蓝方块
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.RBLOCK) ret1, list1 = filter.get(tlabel.RBLOCK)
ret2, list2 = filter.get(tlabel.BBLOCK) if ret1 > 0:
logger.info("[抓方块]# find label")
if (ret1 > 0) or (ret2 > 0):
logger.info("[TASK1]# find label")
return True return True
else: else:
return False return False
def exec(): def exec(self):
for _ in range(3):
by_cmd.send_speed_x(7)
by_cmd.send_speed_omega(0)
time.sleep(0.1)
logger.info("abcd")
cfg = toml.load('cfg_subtask.toml') # 加载任务配置 cfg = toml.load('cfg_subtask.toml') # 加载任务配置
while True:
# logger.info("等待进入准确区域")
ret, error = filter.aim_near(tlabel.RBLOCK)
while not ret:
ret, error = filter.aim_near(tlabel.RBLOCK)
# logger.info(error)
if abs(error) < 5:
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
break
ret, error = filter.aim_near(tlabel.RBLOCK)
while not ret:
ret, error = filter.aim_near(tlabel.RBLOCK)
time.sleep(1)
logger.error(error)
if abs(error) > 5:
logger.info("校准中")
if error > 0:
by_cmd.send_distance_x(-10, int(error*3))
else:
by_cmd.send_distance_x(10, int(-error*3))
logger.error(error)
time.sleep(1)
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
time.sleep(2)
by_cmd.send_position_axis_z(10, 150)
time.sleep(5)
by_cmd.send_angle_claw_arm(127)
time.sleep(1)
by_cmd.send_position_axis_x(4, 140)
time.sleep(4)
by_cmd.send_angle_claw_arm(220)
by_cmd.send_angle_claw(90)
time.sleep(1)
by_cmd.send_distance_axis_z(10, -70)
time.sleep(3)
by_cmd.send_angle_claw(27)
by_cmd.send_distance_axis_z(10, 10)
time.sleep(2)
by_cmd.send_distance_axis_x(4, -100)
time.sleep(1)
by_cmd.send_distance_axis_z(10, -40)
time.sleep(3)
by_cmd.send_angle_claw(35)
time.sleep(1)
by_cmd.send_position_axis_z(10, 150)
time.sleep(3)
by_cmd.send_position_axis_x(2, 140)
# 抓取第二个块后 收爪
time.sleep(3)
by_cmd.send_position_axis_x(4, 0)
def nexec(self):
pass pass
# 紧急转移 # 紧急转移
class put_block(): class put_block():
def find(): def init(self):
logger.info("紧急转移初始化")
def find(self):
# 目标检测医院 # 目标检测医院
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.HOSPITAL) ret1, list1 = filter.get(tlabel.HOSPITAL)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
cfg = toml.load('cfg_subtask.toml') # 加载任务配置 cfg = toml.load('cfg_subtask.toml') # 加载任务配置
logger.info("找到医院")
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
by_cmd.send_position_axis_z(10, 150)
time.sleep(3)
# TODO 切换爪子方向
by_cmd.send_position_axis_x(2, 140)
time.sleep(2)
by_cmd.send_position_axis_z(10, 170)
pass pass
# 整装上阵 # 整装上阵
class get_bball(): class get_bball():
def find(): def init(self):
by_cmd.send_position_axis_x(2, 140)
logger.info("整装上阵初始化")
time.sleep(0.5)
if (by_cmd.send_angle_camera(90) == -1):
by_cmd.send_angle_camera(90)
def find(self):
# 目标检测黄球 # 目标检测黄球
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.YBALL) ret1, list1 = filter.get(tlabel.YBALL)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
logger.info("找到黄色球")
for _ in range(3):
by_cmd.send_speed_x(7)
by_cmd.send_speed_omega(0)
time.sleep(0.1)
while True:
# logger.info("等待进入准确区域")
ret, error = filter.aim_near(tlabel.YBALL)
while not ret:
ret, error = filter.aim_near(tlabel.YBALL)
# logger.info(error)
if abs(error) < 5:
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
break
ret, error = filter.aim_near(tlabel.YBALL)
while not ret:
ret, error = filter.aim_near(tlabel.YBALL)
time.sleep(1)
logger.error(error)
if abs(error) > 5:
logger.info("校准中")
if error > 0:
by_cmd.send_distance_x(-10, int(error*3))
else:
by_cmd.send_distance_x(10, int(-error*3))
logger.error(error)
time.sleep(1)
if (by_cmd.send_angle_camera(0) == -1):
by_cmd.send_angle_camera(0)
by_cmd.send_position_axis_z(20, 160)
time.sleep(2)
by_cmd.send_position_axis_x(2, 70)
time.sleep(2)
by_cmd.send_angle_claw(90)
time.sleep(0.2)
by_cmd.send_position_axis_x(2, 0)
time.sleep(2)
by_cmd.send_angle_claw(27)
time.sleep(1)
by_cmd.send_position_axis_z(20, 180)
time.sleep(1)
by_cmd.send_position_axis_x(4, 45)
time.sleep(1)
by_cmd.send_position_axis_z(20, 140)
time.sleep(3)
by_cmd.send_position_axis_x(2, 140)
time.sleep(2)
by_cmd.send_angle_claw(90)
pass pass
# 通信抢修 # 通信抢修
class up_tower(): class up_tower():
def find(): def init(self):
logger.info("通信抢修初始化")
def find(self):
# 目标检测通信塔 # 目标检测通信塔
filter = label_filter(None) ret1, list1 = filter.get(tlabel.TOWER)
ret1, list1 = filter.get(tlabel.YBALL)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
pass logger.info("找到塔")
for _ in range(3):
by_cmd.send_speed_x(7)
by_cmd.send_speed_omega(0)
time.sleep(0.1)
while True:
# logger.info("等待进入准确区域")
ret, error = filter.aim_near(tlabel.TOWER)
while not ret:
ret, error = filter.aim_near(tlabel.TOWER)
# logger.info(error)
if abs(error) < 5:
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
break
ret, error = filter.aim_near(tlabel.TOWER)
while not ret:
ret, error = filter.aim_near(tlabel.TOWER)
time.sleep(1)
logger.error(error)
if abs(error) > 5:
logger.info("校准中")
if error > 0:
by_cmd.send_distance_x(-10, int(error*3))
else:
by_cmd.send_distance_x(10, int(-error*3))
logger.error(error)
time.sleep(1)
# 高空排险 # 高空排险
class get_rball(): class get_rball():
def find(): def init(self):
logger.info("高空排险初始化")
if (by_cmd.send_angle_camera(0) == -1):
by_cmd.send_angle_camera(0)
def find(self):
# 目标检测红球 # 目标检测红球
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.RBALL) ret1, list1 = filter.get(tlabel.RBALL)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
logger.info("找到红球")
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
time.sleep(1)
pass pass
# 派发物资 # 派发物资
class put_bball(): class put_bball():
def find(): def init(self):
logger.info("派发物资初始化")
if (by_cmd.send_angle_camera(90) == -1):
by_cmd.send_angle_camera(90)
def find(self):
# 目标检测通信塔 # 目标检测通信塔
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.BASKET) ret1, list1 = filter.get(tlabel.BASKET)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
pass logger.info("找到篮筐")
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
time.sleep(1)
pass
direction = tlabel.RMARK
direction_left = 0
direction_right = 0
# 物资盘点 # 物资盘点
class put_hanoi1(): class put_hanoi1():
def find(): def init(self):
logger.info("物资盘点 1 初始化")
socket.send_string("2")
socket.recv()
def find(self):
global direction
global direction_left
global direction_right
# 目标检测左右转向标识 # 目标检测左右转向标识
filter = label_filter(None) # TODO 框的大小判断距离
ret1, list1 = filter.get(tlabel.MARKL) ret1, list1 = filter.get(tlabel.RMARK)
ret2, list2 = filter.get(tlabel.MARKR) ret2, list2 = filter.get(tlabel.LMARK)
if ret1:
if (ret1 > 0) or (ret2 > 0): logger.info("向右拐")
direction_right += 1
return True
elif ret2:
logger.info("向左拐")
direction_left += 1
return True return True
else:
return False return False
def exec(): def exec(self):
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
time.sleep(0.2)
# if direction == tlabel.RMARK:
if direction_right > direction_left:
by_cmd.send_angle_omega(-20,500)
else:
by_cmd.send_angle_omega(20,500)
time.sleep(0.2)
if (by_cmd.send_angle_camera(180) == -1):
by_cmd.send_angle_camera(180)
time.sleep(2)
socket.send_string("1")
socket.recv()
pass pass
class put_hanoi2(): class put_hanoi2():
def find(): def init(self):
logger.info("物资盘点 2 初始化")
def find(self):
# 目标检测左右转向标识 # 目标检测左右转向标识
filter = label_filter(None) ret1, list1 = filter.get(tlabel.LPILLER)
ret1, list1 = filter.get(tlabel.LPILLAR)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
logger.info("找到最大块")
for _ in range(3):
by_cmd.send_speed_x(0)
time.sleep(0.2)
by_cmd.send_speed_omega(0)
time.sleep(1)
pass pass
# 应急避险 # 应急避险
class move_area(): class move_area():
def find(): def init(self):
logger.info("应急避险初始化")
if (by_cmd.send_angle_camera(180) == -1):
by_cmd.send_angle_camera(180)
def find(self):
# 目标检测标志牌 # 目标检测标志牌
# TODO 如何确保在都检测标志牌的情况下,和下一个任务进行区分 # TODO 如何确保在都检测标志牌的情况下,和下一个任务进行区分
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.SIGN) ret1, list1 = filter.get(tlabel.SIGN)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
logger.info("找到标示牌")
pass pass
# 扫黑除暴 # 扫黑除暴
class kick_ass(): class kick_ass():
def find(): def init(self):
logger.info("扫黑除暴初始化")
def find(self):
# 目标检测标志牌 # 目标检测标志牌
# TODO 如何确保在都检测标志牌的情况下,和上一个任务进行区分 # TODO 如何确保在都检测标志牌的情况下,和上一个任务进行区分
filter = label_filter(None)
ret1, list1 = filter.get(tlabel.SIGN) ret1, list1 = filter.get(tlabel.SIGN)
if ret1 > 0: if ret1 > 0:
return True return True
else: else:
return False return False
def exec(): def exec(self):
logger.info("找到标示牌")
pass pass

197
utils.py
View File

@@ -1,49 +1,176 @@
from enum import Enum from enum import Enum
import numpy as np
# 根据标签修改 # 根据标签修改
# class tlabel(Enum):
# BBLOCK = 5 # 蓝色方块
# RBLOCK = 2 # 红色方块
# HOSPITAL = 3 # 医院
# BBALL = 4 # 蓝球
# YBALL = 5 # 黄球
# TOWER = 6 # 通信塔
# RBALL = 7 # 红球
# BASKET = 8 # 球筐
# MARKL = 9 # 指向标
# MARKR = 10 # 指向标
# SPILLAR = 11 # 小柱体 (红色)
# MPILLAR = 12 # 中柱体 (蓝色)
# LPILLAR = 13 # 大柱体 (红色)
# SIGN = 14 # 文字标牌
# TARGET = 15 # 目标靶
# SHELTER = 16 # 停车区
# BASE = 17 # 基地
class tlabel(Enum): class tlabel(Enum):
BBLOCK = 1 # 蓝色方块 TOWER = 0
RBLOCK = 2 # 红色方块 SIGN = 1
HOSPITAL = 3 # 医院 SHELTER = 2
BBALL = 4 # 蓝球 HOSPITAL = 3
YBALL = 5 # 黄球 BASKET = 4
TOWER = 6 # 通信塔 BASE = 5
RBALL = 7 # 红球 YBALL = 6
BASKET = 8 # 球筐 SPILLER = 7
MARKL = 9 # 指向标 RMARK = 8
MARKR = 10 # 指向标 RBLOCK = 9
SPILLAR = 11 # 小柱体 (红色) RBALL = 10
MPILLAR = 12 # 中柱体 (蓝色) MPILLER = 11
LPILLAR = 13 # 大柱体 (红色) LPILLER = 12
SIGN = 14 # 文字标牌 LMARK = 13
TARGET = 15 # 目标靶 BBLOCK = 14
SHELTER = 16 # 停车区 BBALL = 15
BASE = 17 # 基地 test_resp = {
'code': 0,
'data': np.array([
[4., 0.97192055, 26.64415, 228.26755, 170.16872, 357.6216],
[4., 0.97049206, 474.0152, 251.2854, 612.91644, 381.6831],
[5., 0.972649, 250.84174, 238.43622, 378.115, 367.34906]
])
}
test1_resp = {
'code': 0,
'data': np.array([])
}
class label_filter: class label_filter:
def __init__(self, list_src): def __init__(self, socket, threshold=0.6):
self.num = 0 self.num = 0
self.pos = [] self.pos = []
self.list = list_src # 获取目标检测输出的接口 (含标签,位置,置信度) self.socket = socket
# TODO 添加置信度阈值 self.threshold = threshold
pass
self.img_size = (320, 240)
def get_resp(self):
self.socket.send_string('')
response = self.socket.recv_pyobj()
return response
def switch_camera(self,camera_id):
if camera_id == 1 or camera_id == 2:
self.socket.send_string(f'{camera_id}')
response = self.socket.recv_pyobj()
return response
def filter_box(self,data):
if len(data) > 0:
expect_boxes = (data[:, 1] > self.threshold) & (data[:, 0] > -1)
np_boxes = data[expect_boxes, :]
results = [
[
item[0], # 'label':
item[1], # 'score':
item[2], # 'xmin':
item[3], # 'ymin':
item[4], # 'xmax':
item[5] # 'ymax':
]
for item in np_boxes
]
if len(results) > 0:
return True, np.array(results)
return False, None
def get(self, tlabel): def get(self, tlabel):
# TODO 循环查找匹配的标签值 # 循环查找匹配的标签值
# TODO 返回对应标签的个数,以及坐标列表 # 返回对应标签的个数,以及坐标列表
# TODO self.filter_box none judge
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
if ret:
expect_boxes = (results[:, 0] == tlabel.value)
boxes = results[expect_boxes, :]
self.num = len(boxes)
self.pos = boxes[:, 2:] # [[x1 y1 x2 y2]]
return self.num, self.pos return self.num, self.pos
return 0, []
def find(self, tlabel): def find(self, tlabel):
# TODO 遍历返回的列表,有对应标签则返回 True # 遍历返回的列表,有对应标签则返回 True
response = self.get_resp()
if response['code'] == 0:
ret, results = self.filter_box(response['data'])
if ret:
expect_boxes = (results[:, 0] == tlabel.value)
boxes = results[expect_boxes, :]
if len(boxes) != 0:
return True
return False return False
def aim_left(self, tlabel): def aim_left(self, tlabel):
# TODO 如果标签存在,则返回列表中位置最靠左的目标框和中心的偏移值 # 如果标签存在,则返回列表中位置最靠左的目标框和中心的偏移值
error = 0 response = self.get_resp()
return error if response['code'] == 0:
ret, results = self.filter_box(response['data'])
if ret:
expect_boxes = (results[:, 0] == tlabel.value)
boxes = results[expect_boxes, :]
if len(boxes) == 0:
return (False, )
xmin_values = boxes[:, 2] # xmin
xmin_index = np.argmin(xmin_values)
error = (boxes[xmin_index][4] + boxes[xmin_index][2] - self.img_size[0]) / 2
return (True, error)
return (False, )
def aim_right(self, tlabel): def aim_right(self, tlabel):
# TODO 如果标签存在,则返回列表中位置最靠右的目标框和中心的偏移值 # 如果标签存在,则返回列表中位置最靠右的目标框和中心的偏移值
error = 0 response = self.get_resp()
return error if response['code'] == 0:
ret, results = self.filter_box(response['data'])
if ret:
expect_boxes = (results[:, 0] == tlabel.value)
boxes = results[expect_boxes, :]
if len(boxes) == 0:
return (False, )
xmax_values = boxes[:, 4] # xmax
xmax_index = np.argmax(xmax_values)
error = (boxes[xmax_index][4] + boxes[xmax_index][2] - self.img_size[0]) / 2
return (True, error)
return (False, )
def aim_near(self, tlabel): def aim_near(self, tlabel):
# TODO 如果标签存在,则返回列表中位置最近的目标框和中心的偏移值 # 如果标签存在,则返回列表中位置最近的目标框和中心的偏移值
error = 0 response = self.get_resp()
return error if response['code'] == 0:
ret, results = self.filter_box(response['data'])
if ret:
expect_boxes = (results[:, 0] == tlabel.value)
boxes = results[expect_boxes, :]
if len(boxes) == 0:
return (False, 0)
center_x_values = np.abs(boxes[:, 2] + boxes[:, 4] - self.img_size[0])
center_x_index = np.argmin(center_x_values)
error = (boxes[center_x_index][4] + boxes[center_x_index][2] - self.img_size[0]) / 2
return (True, error+15)
return (False, 0)
# class Calibrate:
# def __init__(self,by_cmd):
# # 车控制对象初始化
# self.by_cmd = by_cmd
# def aim(self,error):
# self.by_cmd.send_distance_x(error,)
if __name__ == '__main__':
obj = label_filter(None)
# results = obj.filter_box(resp['data'])
# expect_boxes = (results[:, 0] == tlabel.SPILLAR.value)
# np_boxes = results[expect_boxes, :]
# print(np_boxes[:, 2:])
# print(len(np_boxes))
print(obj.find(tlabel.BBALL))
print(obj.aim_left(tlabel.BBALL))
print(obj.aim_right(tlabel.BBALL))
print(obj.aim_near(tlabel.BBALL))
print(obj.get(tlabel.HOSPITAL))