SMPLify-从二维坐标到人体SMPL和关节转动

之前我一直在看SMPLify的代码,确实运行起来了,但是速度太慢,可能因为用的优化器是chumpy。实际上之所以慢是因为初值不好:用深度学习的方法可以快速猜出初值,再用优化的方法进行精化就很快,这个思路就是SPIN。

后来我发现SPIN里面的SMPLify写的很好,并且是使用了SMPL-X这个库的,把渲染等工作直接调库,注意SMPL-X也是支持SMPL的。

这里讲解SPIN中的SMPLify

import torch
import os

from models.smpl import SMPL
from .losses import camera_fitting_loss, body_fitting_loss
import config
import constants

# For the GMM prior, we use the GMM implementation of SMPLify-X
# https://github.com/vchoutas/smplify-x/blob/master/smplifyx/prior.py
from .prior import MaxMixturePrior
# NOTE: a single SMPLify instance can fit several people at once — just set batch_size.
class SMPLify():
    """Implementation of single-stage SMPLify.

    Two-stage fitting: (1) optimize camera translation and global orientation
    against the torso keypoints, then (2) freeze the camera and optimize the
    full body pose and shape against all (non-ignored) 2D keypoints.
    """

    def __init__(self,
                 step_size=1e-2,
                 batch_size=66,
                 num_iters=100,
                 focal_length=5000,
                 device=torch.device('cuda')):
        # Store options
        self.device = device
        self.focal_length = focal_length
        self.step_size = step_size

        # Ignore the following joints during fitting (their confidence is zeroed
        # below, so they contribute nothing to the reprojection term).
        # The 2D keypoints and the SMPL joints share one naming convention, so
        # no long joint-correspondence table is needed here.
        ign_joints = ['OP Neck', 'OP RHip', 'OP LHip', 'Right Hip', 'Left Hip']
        self.ign_joints = [constants.JOINT_IDS[i] for i in ign_joints]
        self.num_iters = num_iters
        # GMM pose prior (SMPLify-X implementation), loaded from the 'data' folder
        self.pose_prior = MaxMixturePrior(prior_folder='data',
                                          num_gaussians=8,
                                          dtype=torch.float32).to(device)
        # Load SMPL model
        self.smpl = SMPL(config.SMPL_MODEL_DIR,
                         batch_size=batch_size,
                         create_transl=False).to(self.device)

    # Main entry point: set the initial values, pass in the 2D keypoints, and get
    # back the mesh, 3D joints, per-joint axis-angle pose, shape parameters,
    # camera translation and the final reprojection error.
    def __call__(self, init_pose, init_betas, init_cam_t, camera_center, keypoints_2d):
        """Perform body fitting.
        Input:
            init_pose: SMPL pose estimate
            init_betas: SMPL betas estimate
            init_cam_t: Camera translation estimate
            camera_center: Camera center location
            keypoints_2d: Keypoints used for the optimization
        Returns:
            vertices: Vertices of optimized shape
            joints: 3D joints of optimized shape
            pose: SMPL pose parameters of optimized shape
            betas: SMPL beta parameters of optimized shape
            camera_translation: Camera translation
            reprojection_loss: Final joint reprojection loss
        """

        batch_size = init_pose.shape[0]

        # Make camera translation a learnable parameter
        camera_translation = init_cam_t.clone()

        # Get joint confidence
        # NOTE(review): joints_2d / joints_conf are views into keypoints_2d, so the
        # in-place confidence zeroing below mutates the caller's tensor — confirm
        # that callers do not reuse keypoints_2d afterwards.
        joints_2d = keypoints_2d[:, :, :2]
        joints_conf = keypoints_2d[:, :, -1]

        # Split SMPL pose to body pose and global orientation
        body_pose = init_pose[:, 3:].detach().clone()
        global_orient = init_pose[:, :3].detach().clone()
        betas = init_betas.detach().clone()

        # Step 1: Optimize camera translation and body orientation
        # Setting requires_grad=True selects what is optimized: this first stage
        # fits only the camera translation (and global orientation), using just
        # the four torso joints — hence camera_fitting_loss.
        body_pose.requires_grad=False
        betas.requires_grad=False
        global_orient.requires_grad=True
        camera_translation.requires_grad = True

        camera_opt_params = [global_orient, camera_translation]
        camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999))

        for i in range(self.num_iters):
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas)
            model_joints = smpl_output.joints
            loss = camera_fitting_loss(model_joints, camera_translation,
                                       init_cam_t, camera_center,
                                       joints_2d, joints_conf, focal_length=self.focal_length)
            camera_optimizer.zero_grad()
            loss.backward()
            camera_optimizer.step()

        # Fix camera translation after optimizing camera
        camera_translation.requires_grad = False

        # Step 2: Optimize body joints — now everything except the camera is free.
        body_pose.requires_grad=True
        betas.requires_grad=True
        global_orient.requires_grad=True
        camera_translation.requires_grad = False
        body_opt_params = [body_pose, betas, global_orient]

        # For joints ignored during fitting, set the confidence to 0
        joints_conf[:, self.ign_joints] = 0.

        body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999))
        for i in range(self.num_iters):
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas)
            model_joints = smpl_output.joints
            loss = body_fitting_loss(body_pose, betas, model_joints, camera_translation, camera_center,
                                     joints_2d, joints_conf, self.pose_prior,
                                     focal_length=self.focal_length)
            body_optimizer.zero_grad()
            loss.backward()
            body_optimizer.step()

        # Get final loss value
        with torch.no_grad():
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas, return_full_pose=True)
            model_joints = smpl_output.joints
            reprojection_loss = body_fitting_loss(body_pose, betas, model_joints, camera_translation, camera_center,
                                                  joints_2d, joints_conf, self.pose_prior,
                                                  focal_length=self.focal_length,
                                                  output='reprojection')

        vertices = smpl_output.vertices.detach()
        joints = smpl_output.joints.detach()
        pose = torch.cat([global_orient, body_pose], dim=-1).detach()
        betas = betas.detach()

        return vertices, joints, pose, betas, camera_translation, reprojection_loss

    def get_fitting_loss(self, pose, betas, cam_t, camera_center, keypoints_2d):
        """Given body and camera parameters, compute reprojection loss value.
        Input:
            pose: SMPL pose parameters
            betas: SMPL beta parameters
            cam_t: Camera translation
            camera_center: Camera center location
            keypoints_2d: Keypoints used for the optimization
        Returns:
            reprojection_loss: Final joint reprojection loss
        """

        batch_size = pose.shape[0]

        # Get joint confidence
        joints_2d = keypoints_2d[:, :, :2]
        joints_conf = keypoints_2d[:, :, -1]
        # For joints ignored during fitting, set the confidence to 0
        # NOTE(review): joints_conf is a view into keypoints_2d, so this zeroing
        # mutates the caller's tensor in place — confirm callers do not reuse it.
        joints_conf[:, self.ign_joints] = 0.

        # Split SMPL pose to body pose and global orientation
        body_pose = pose[:, 3:]
        global_orient = pose[:, :3]

        with torch.no_grad():
            smpl_output = self.smpl(global_orient=global_orient,
                                    body_pose=body_pose,
                                    betas=betas, return_full_pose=True)
            model_joints = smpl_output.joints
            reprojection_loss = body_fitting_loss(body_pose, betas, model_joints, cam_t, camera_center,
                                                  joints_2d, joints_conf, self.pose_prior,
                                                  focal_length=self.focal_length,
                                                  output='reprojection')

        return reprojection_loss
def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center,
                      joints_2d, joints_conf, pose_prior,
                      focal_length=5000, sigma=100, pose_prior_weight=4.78,
                      shape_prior_weight=5, angle_prior_weight=15.2,
                      output='sum'):
    """Loss function for the body-fitting stage of SMPLify.

    Combines a confidence-weighted robust 2D reprojection term with GMM pose,
    angle and shape priors.

    Args:
        body_pose: SMPL body pose parameters, shape (B, 69).
        betas: SMPL shape parameters.
        model_joints: 3D model joints to project.
        camera_t: Camera translation.
        camera_center: Camera center location.
        joints_2d: Target 2D keypoints.
        joints_conf: Per-joint detection confidence (used squared as a weight).
        pose_prior: Callable GMM pose prior, evaluated on (body_pose, betas).
        focal_length, sigma: Projection focal length and gmof robustness scale.
        pose_prior_weight, shape_prior_weight, angle_prior_weight: Term weights
            (applied squared).
        output: 'sum' for the scalar total loss, 'reprojection' for the
            per-sample reprojection loss.
    Returns:
        Scalar total loss, or per-sample reprojection loss tensor.
    Raises:
        ValueError: If ``output`` is neither 'sum' nor 'reprojection'.
    """

    batch_size = body_pose.shape[0]
    # Identity camera rotation: orientation is absorbed into global_orient.
    rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust (Geman-McClure) reprojection error
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1)

    # Pose prior loss
    pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas)

    # Angle prior for knees and elbows
    angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1)

    # Regularizer to prevent betas from taking large values
    shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1)

    total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss

    if output == 'sum':
        return total_loss.sum()
    if output == 'reprojection':
        return reprojection_loss
    # Previously an unknown mode fell through and silently returned None,
    # hiding caller typos; fail loudly instead.
    raise ValueError("output must be 'sum' or 'reprojection', got {!r}".format(output))
# Loss used while fitting only the camera translation (stage 1 of SMPLify)
def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, joints_2d, joints_conf,
                        focal_length=5000, depth_loss_weight=100):
    """
    Loss function for camera optimization.

    Projects the model's torso joints and compares them against the 2D torso
    detections, preferring OpenPose keypoints when all four are confident and
    falling back to the ground-truth set otherwise. A depth term keeps the
    optimized translation close to the initial estimate ``camera_t_est``.
    """

    # Project model joints with an identity rotation (orientation is handled
    # by the body's global_orient, not the camera).
    num_samples = model_joints.shape[0]
    identity_rot = torch.eye(3, device=model_joints.device).unsqueeze(0).expand(num_samples, -1, -1)
    proj = perspective_projection(model_joints, identity_rot, camera_t,
                                  focal_length, camera_center)

    # Torso joint indices for both the OpenPose and ground-truth conventions.
    op_joints = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
    gt_joints = ['Right Hip', 'Left Hip', 'Right Shoulder', 'Left Shoulder']
    op_idx = [constants.JOINT_IDS[joint] for joint in op_joints]
    gt_idx = [constants.JOINT_IDS[joint] for joint in gt_joints]

    err_op = (joints_2d[:, op_idx] - proj[:, op_idx]) ** 2
    err_gt = (joints_2d[:, gt_idx] - proj[:, gt_idx]) ** 2

    # Per sample: use the OpenPose torso error only when all four OpenPose
    # detections are valid; otherwise fall back to the GT detections.
    use_op = (joints_conf[:, op_idx].min(dim=-1)[0][:, None, None] > 0).float()
    reproj_loss = (use_op * err_op + (1 - use_op) * err_gt).sum(dim=(1, 2))

    # Penalize deviation of the camera depth from its initial estimate.
    depth_loss = (depth_loss_weight ** 2) * (camera_t[:, 2] - camera_t_est[:, 2]) ** 2

    return (reproj_loss + depth_loss).sum()
# 这里重新包装了SMPL模型,重新包装了smplx的库
import torch
import numpy as np
import smplx
from smplx import SMPL as _SMPL
from smplx.body_models import ModelOutput
from smplx.lbs import vertices2joints

import config
import constants

class SMPL(_SMPL):
    """Extension of the official smplx SMPL model that regresses extra joints
    from the mesh and reorders the full joint set via ``constants.JOINT_MAP``.
    """

    def __init__(self, *args, **kwargs):
        super(SMPL, self).__init__(*args, **kwargs)
        # Linear regressor mapping mesh vertices to additional joints.
        extra_regressor = np.load(config.JOINT_REGRESSOR_TRAIN_EXTRA)
        self.register_buffer('J_regressor_extra',
                             torch.tensor(extra_regressor, dtype=torch.float32))
        # Index map selecting/reordering joints into the expected convention.
        joint_ids = [constants.JOINT_MAP[i] for i in constants.JOINT_NAMES]
        self.joint_map = torch.tensor(joint_ids, dtype=torch.long)

    def forward(self, *args, **kwargs):
        kwargs['get_skin'] = True
        base = super(SMPL, self).forward(*args, **kwargs)
        # Regress the extra joints directly from the posed mesh vertices,
        # append them to the model's own joints, then reorder.
        extra = vertices2joints(self.J_regressor_extra, base.vertices)
        all_joints = torch.cat([base.joints, extra], dim=1)
        reordered = all_joints[:, self.joint_map, :]
        return ModelOutput(vertices=base.vertices,
                           global_orient=base.global_orient,
                           body_pose=base.body_pose,
                           joints=reordered,
                           betas=base.betas,
                           full_pose=base.full_pose)