본문 바로가기

Python

Scatter plot with subfigures

반응형
import pickle

# Visualize deformable-attention sampling points: for every query location on
# the reference camera's low-resolution grid and every attention head, draw
#   - top:    the reference image with the query location marked, and
#   - bottom: each camera image with its sampling points, marker area scaled
#             by the attention weight.
#
# NOTE(review): `plt` (matplotlib.pyplot) is used below but is not imported in
# this file; it is presumably provided by the surrounding notebook/session.

# SECURITY: pickle.load can execute arbitrary code -- only load result files
# you generated yourself.
with open('attn_result.pickle', 'rb') as fr:
    attn_result = pickle.load(fr)

# Camera whose feature grid supplies the query locations.
ref_cam = 'cam0'

image_ref = attn_result[ref_cam]['image']
Hh, Wh = image_ref.shape[:2]

# Axis order per the original note: Hl, Wl, nhead, ncam, npts, 2
# (the indexing below relies on axes 0-3 being Hl, Wl, head, cam).
position = attn_result[ref_cam]['sample_position']
weight = attn_result[ref_cam]['attention_weight']
Hl, Wl, n_head, n_cam = position.shape[:4]

# Factor mapping a low-res grid index to reference-image pixels.
scale_y = Hh / Hl
scale_x = Wh / Wl

# Target images do not depend on the query; fetch them once.
images_tgt = [attn_result['cam' + str(cam)]['image'] for cam in range(n_cam)]

# Two rows of target cameras (2 x 3 for the usual six cameras).
n_cols = max(1, (n_cam + 1) // 2)

for xl in range(Wl):
    for yl in range(Hl):
        for head in range(n_head):
            fig = plt.figure()
            subfigs = fig.subfigures(2, 1)

            # Top: reference image with the query location marked.
            # NOTE(review): this marks the top-left corner of the grid cell;
            # use (xl + 0.5) * scale_x if the model samples at cell centres.
            ax_ref = subfigs[0].add_subplot()
            ax_ref.imshow(image_ref)
            ax_ref.scatter(xl * scale_x, yl * scale_y, s=200)
            ax_ref.set_title(f'Head # {head}')

            # Bottom: sampling points in every camera, area ~ attention weight.
            for cam, image_tgt in enumerate(images_tgt):
                corr_pos = position[yl, xl, head, cam].reshape(-1, 2)
                corr_wgt = weight[yl, xl, head, cam].reshape(-1)
                # Positions are multiplied by an image size, i.e. presumably
                # normalized to [0, 1]; scale by the image they are drawn on.
                Ht, Wt = image_tgt.shape[:2]
                ax_tgt = subfigs[1].add_subplot(2, n_cols, cam + 1)
                ax_tgt.imshow(image_tgt)
                ax_tgt.scatter(
                    corr_pos[:, 0] * Wt,
                    corr_pos[:, 1] * Ht,
                    s=100 * corr_wgt,
                    c='green',
                )

            plt.show()
            # Release the figure; otherwise one figure per (xl, yl, head)
            # stays open and memory grows on non-inline backends.
            plt.close(fig)