# JustTwirk/moves_videopose3d.py
# 2025-12-08 20:25:20 +01:00
# 117 lines, 3.3 KiB, Python
import cv2
import torch
import numpy as np
from common.model import TemporalModel
from common.camera import *
# from common.utils import evaluate
from ultralytics import YOLO
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # not strictly needed in newer matplotlib
import time
# --- 1. Initialize the VideoPose3D temporal lifting model ---
# The pretrained H36M/Detectron-COCO checkpoint is the 243-frame architecture:
# five temporal blocks of width 3 (receptive field 3^5 = 243, matching
# BUFFER_SIZE below). With only four blocks (field 81) the layer shapes do
# not match the checkpoint, and strict=False would silently skip them.
model_3d = TemporalModel(
    num_joints_in=17,   # 17 2D input joints (COCO-style detections)
    in_features=2,      # (x, y) per joint
    num_joints_out=17,  # 17 3D output joints (H36M layout)
    filter_widths=[3, 3, 3, 3, 3],  # receptive field 243 frames
    causal=False,
)
chk = torch.load("checkpoint/pretrained_h36m_detectron_coco.bin", map_location='cpu')
# VideoPose3D checkpoints store the weights under the 'model_pos' key;
# loading the raw dict matches nothing. Load strictly so any remaining
# architecture mismatch fails loudly instead of being ignored.
model_3d.load_state_dict(chk['model_pos'])
model_3d.eval()
# --- 2. Initialize the YOLO pose model (per-person 2D keypoints) ---
yolo = YOLO('yolo11s-pose.pt') # small variant, chosen for speed
# --- 3. Open the input video and set up the sliding keypoint buffer ---
cap = cv2.VideoCapture("input.mp4")
frame_buffer = []  # per-frame [17, 2] normalized 2D keypoint arrays
BUFFER_SIZE = 243 # VideoPose3D lifts a temporal sequence, not single frames
# --- Interactive 3D figure: one scatter for joints, one line per bone ---
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection='3d')
# Scatter artist for the 17 joints; updated in place each frame.
scatter = ax.scatter([], [], [], c='r')
# Bone list as (joint_a, joint_b) index pairs into the 17-joint pose.
skeleton = [
    (0, 1), (1, 2), (2, 3), (0, 4), (4, 5), (5, 6),
    (0, 7), (7, 8), (8, 9), (7, 12), (12, 13), (13, 14),
    (7, 10), (10, 11), (11, 12),
]
# One reusable Line3D artist per bone (ax.plot returns a list; take [0]).
skeleton_lines = [ax.plot([], [], [], c='b')[0] for _ in skeleton]
ax.set_xlim3d(-1, 1)
ax.set_ylim3d(-1, 1)
ax.set_zlim3d(0, 2)
ax.view_init(elev=20, azim=-70)
plt.ion()   # interactive mode so the loop can redraw without blocking
plt.show()
# --- Main loop: detect 2D keypoints per frame, lift a full buffer to 3D ---
while True:
    ret, frame = cap.read()
    if not ret:
        break

    # --- 4. 2D keypoint detection with YOLO ---
    results = yolo(frame)
    if len(results) == 0 or len(results[0].keypoints.xy) == 0:
        continue

    # Assume a single person per frame (simplification: take detection 0).
    # .cpu() makes this safe when YOLO runs on GPU; np.array() on a CUDA
    # tensor would raise.
    keypoints = results[0].keypoints.xy[0].cpu().numpy()  # [17, 2] pixels

    # Normalize pixel coordinates to the convention the pretrained
    # VideoPose3D weights expect: x in [-1, 1], y scaled by the same factor
    # (normalize_screen_coordinates comes from the common.camera star import).
    keypoints = normalize_screen_coordinates(
        keypoints, w=frame.shape[1], h=frame.shape[0]
    )
    frame_buffer.append(keypoints)

    # --- 5. Once the sequence is full, predict the 3D pose ---
    # (The bone list is the module-level `skeleton`; the previous in-loop
    # redefinition of the identical list has been removed.)
    if len(frame_buffer) == BUFFER_SIZE:
        seq_2d = torch.tensor(np.array(frame_buffer)).unsqueeze(0).float()
        with torch.no_grad():
            pred_3d = model_3d(seq_2d)
        pose_3d = pred_3d[0, -1].numpy()  # [17, 3], most recent output frame

        # Remap axes for display: matplotlib's z axis is "up".
        pose_3d = pose_3d[:, [0, 2, 1]]  # X, Z, Y
        pose_3d[:, 2] *= -1

        # --- 3D plot update: mutate existing artists instead of re-plotting ---
        xs, ys, zs = pose_3d[:, 0], pose_3d[:, 1], pose_3d[:, 2]
        scatter._offsets3d = (xs, ys, zs)
        for idx, (a, b) in enumerate(skeleton):
            skeleton_lines[idx].set_data([xs[a], xs[b]], [ys[a], ys[b]])
            skeleton_lines[idx].set_3d_properties([zs[a], zs[b]])
        plt.draw()
        plt.pause(0.001)
        print(pose_3d.tolist())

        # Slide the window forward by one frame.
        frame_buffer.pop(0)

cap.release()
cv2.destroyAllWindows()