film_sportowe/main.py

94 lines
2.3 KiB
Python
Raw Normal View History

2023-01-29 14:29:36 +01:00
from imageai.Detection import VideoObjectDetection
import matplotlib.pyplot as plt
# Module-level accumulators filled by forFrame() while the video is processed;
# each entry is a (detection index, box_points[0], box_points[1]) tuple.
person_points, ball_points = [], []

# YOLOv3 detector used for the whole clip.
# NOTE(review): the weights path is absolute and machine-specific.
vid_obj_detect = VideoObjectDetection()
vid_obj_detect.setModelTypeAsYOLOv3()
vid_obj_detect.setModelPath(r"/home/mikolaj/2ait_tech/sportowe/film/yolo.h5")
vid_obj_detect.loadModel()
def forFrame(frame_number, output_array, output_count, detected_frame):
    """Per-frame callback for ``detectObjectsFromVideo``: record box corners.

    For every detection in the frame, append the tuple
    ``(index, box_points[0], box_points[1])`` to the module-level
    ``ball_points`` list (class ``'sports ball'``) or ``person_points`` list
    (class ``'person'``). ``index`` is the detection's position in this
    frame's ``output_array``. Detections of any other class are ignored.

    Args:
        frame_number: Frame identifier supplied by ImageAI (unused here).
        output_array: List of detection dicts with ``'name'`` and
            ``'box_points'`` keys.
        output_count: Per-class summary supplied by ImageAI (unused here).
        detected_frame: Frame supplied by ImageAI because
            ``return_detected_frame=True`` (unused here).
    """
    for index, detection in enumerate(output_array):
        name = detection['name']
        corner_x, corner_y = detection['box_points'][0], detection['box_points'][1]
        if name == 'sports ball':
            ball_points.append((index, corner_x, corner_y))
        elif name == 'person':
            # Only real people feed the player tracks; a plain `else` here would
            # also count cars, chairs, etc. as players.
            person_points.append((index, corner_x, corner_y))
# Run detection over football.mp4 at 20 fps, writing the output video to
# changed_football.mp4; forFrame() is registered as the per-frame callback
# and collects the coordinates used for plotting below.
detected_vid_obj = vid_obj_detect.detectObjectsFromVideo(
    input_file_path=r"football.mp4",
    output_file_path=r"changed_football.mp4",
    frames_per_second=20,
    per_frame_function=forFrame,
    return_detected_frame=True,
    log_progress=True,
)
# Persist the raw detections: the person list on the first line and the ball
# list on the second, each a Python literal readable with ast.literal_eval.
# The file name (including its spelling) is kept as-is in case other tooling
# reads it.
with open('saved_cordinats', 'w', encoding='utf-8') as file:
    file.write(str(person_points) + '\n')
    file.write(str(ball_points) + '\n')
# --- Plot the collected trajectories and save them to wykres.png ---------
fig, ax = plt.subplots()

# Colours/labels for the first six tracks; any further tracks reuse the
# colours cyclically and get generated labels (p7, p8, ...).
COLORS = ['red', 'blue', 'orange', 'pink', 'yellow', 'brown']
LABELS = ['p1', 'p2', 'p3', 'p4', 'p5', 'p6']

# Group person detections by their per-frame detection index. Tracks 0-5 always
# exist (so the legend always lists p1-p6); higher indices are added on demand
# instead of raising KeyError, which the old bare `except: pass` swallowed and
# thereby silently skipped every plot, including the ball.
# NOTE(review): that index is only the position inside one frame's
# output_array, not a stable player identity, so a "track" may jump between
# players from frame to frame.
persons = {k: {'x': [], 'y': []} for k in range(len(LABELS))}
for track_id, corner_x, corner_y in person_points:
    track = persons.setdefault(track_id, {'x': [], 'y': []})
    track['x'].append(corner_x)
    track['y'].append(corner_y)

for k in sorted(persons):
    label = LABELS[k] if k < len(LABELS) else f'p{k + 1}'
    ax.plot(persons[k]['x'], persons[k]['y'],
            color=COLORS[k % len(COLORS)], lw=2, label=label)

ball_x = [point[1] for point in ball_points]
ball_y = [point[2] for point in ball_points]
ax.plot(ball_x, ball_y, color='green', lw=2, label='ball')

# ImageAI box_points are image-pixel coordinates (origin top-left, y grows
# downward), so flip the y axis to draw the paths the same way up as the video.
ax.invert_yaxis()
ax.legend()
plt.savefig('wykres.png')
plt.show()