Python
Scatter plot with subfigures
ddokkddokk
2023. 11. 13. 10:06
반응형
# Visualize where each attention head samples across all cameras.
#
# For every query cell (xl, yl) of the low-resolution grid and every head,
# one figure is drawn with two stacked subfigures:
#   * top    - the reference camera image with the query cell marked;
#   * bottom - one panel per camera showing that head's sampling points,
#              with marker area proportional to the attention weight.
import pickle

import matplotlib.pyplot as plt

# NOTE: pickle.load can execute arbitrary code; only load result files you
# produced yourself.
with open('attn_result.pickle', 'rb') as fr:
    attn_result = pickle.load(fr)

ref_cam = 'cam0'
image_ref = attn_result[ref_cam]['image']
Hh, Wh = image_ref.shape[:2]

# position: (Hl, Wl, nhead, ncam, npts, 2).
# weight is indexed the same way, presumably without the trailing
# coordinate axis (TODO confirm against the producer of the pickle).
position = attn_result[ref_cam]['sample_position']
weight = attn_result[ref_cam]['attention_weight']
Hl, Wl, num_heads, num_cams = position.shape[:4]

# Factors mapping a low-resolution grid cell to reference-image pixels.
scale_y = Hh / Hl
scale_x = Wh / Wl

# Lay the per-camera panels out on two rows (3 columns for 6 cameras).
tgt_rows = 2
tgt_cols = (num_cams + tgt_rows - 1) // tgt_rows

for xl in range(Wl):
    for yl in range(Hl):
        for head in range(num_heads):
            fig = plt.figure()
            subfigs = fig.subfigures(2, 1)

            # Top: the query location on the reference image.
            ax_ref = subfigs[0].add_subplot()
            ax_ref.imshow(image_ref)
            ax_ref.scatter(xl * scale_x, yl * scale_y, s=200)
            ax_ref.set_title(f'Head # {head}')

            # Bottom: this head's sampling points on every camera.
            for cam in range(num_cams):
                image_tgt = attn_result['cam' + str(cam)]['image']
                Ht, Wt = image_tgt.shape[:2]
                # Flatten any point axes into a list of (x, y) pairs and a
                # matching flat list of weights.
                corr_pos = position[yl, xl, head, cam].reshape(-1, 2)
                corr_wgt = weight[yl, xl, head, cam].reshape(-1)

                ax_tgt = subfigs[1].add_subplot(tgt_rows, tgt_cols, cam + 1)
                ax_tgt.imshow(image_tgt)
                # Positions appear to be normalized to [0, 1] (they are
                # multiplied by an image size), so scale by this camera's
                # own size rather than the reference image's.
                ax_tgt.scatter(
                    corr_pos[:, 0] * Wt,
                    corr_pos[:, 1] * Ht,
                    s=100 * corr_wgt,
                    c='green',
                )

            plt.show()
            # Release the figure; otherwise one figure per (xl, yl, head)
            # stays alive for the rest of the run.
            plt.close(fig)