图像检索评价指标:PR曲线的计算与绘制

图像检索评价指标:PR曲线的计算与绘制

大家好,又见面了,我是你们的朋友全栈君。

代码语言:javascript复制# @file name : test2.py

# @brief : 如何绘制PR曲线

# @author : liupc

# @date : 2021/8/2

import numpy as np

from tqdm import tqdm

import matplotlib.pyplot as plt

#计算汉明距离。有几位不同,距离就为几。

def CalcHammingDist(B1, B2):

q = B2.shape[1]

distH = 0.5 * (q - np.dot(B1, B2.transpose()))

return distH

draw_range = [1,2,3,4,5,6,7]

def pr_curve(rF, qF, rL, qL, draw_range=draw_range):

#rf:galleryBinary

#qF: queryBinary

#rL: galleryLabel。7行3列。

#qL: queryLabel。3行3列。

n_query = qF.shape[0] #多少个查询,3

Gnd = (np.dot(qL, rL.transpose()) > 0).astype(np.float32)

'''

print(Gnd) #是一个3行7列的数组。第一行代表gallery的7个元素是否与query[0]同类;第二行代表gallery的7个元素是否与query[1]同类。。。

[[0. 1. 1. 0. 0. 0. 1.] #gallery[0]与query[0]不同类;gallery[1]与query[0]同类;gallery[2]与query[0]同类;gallery[3]与query[0]不同类。。。

[1. 1. 1. 0. 1. 0. 1.]

[0. 0. 1. 1. 0. 1. 0.]]

'''

Rank = np.argsort(CalcHammingDist(qF, rF)) #是一个3行7列的数组。

'''

print(Rank) #是一个3行7列的数组。

[[3 0 2 5 1 4 6] #gallery的七个元组中,与query[0]最近的元素是gallery[3],其次是gallery[0],再次是gallery[2]。。。

[6 1 4 0 2 5 3]

[0 4 2 5 6 1 3]]

'''

P, R = [], []

for k in tqdm(draw_range): #比如k=5

p = np.zeros(n_query) #[0, 0, 0] 分别是query[0]的acc&k, query[1]的acc&k, query[2]的acc&k

r = np.zeros(n_query) #[0, 0, 0] 分别是query[0]的recall&k, query[1]的recall&k, query[2]的recall&k

for it in range(n_query): #比如it=0

gnd = Gnd[it] #[0. 1. 1. 0. 0. 0. 1.]

gnd_all = np.sum(gnd) #3,为了求召回率

if gnd_all == 0: #如果没有对的,那准确率和召回率肯定都是0,不用继续求了

continue

asc_id = Rank[it][:k] #[3 0 2 5 1]

gnd = gnd[asc_id] #[0 0 1 0 1]

gnd_r = np.sum(gnd) #前k个结果中对了2个。

p[it] = gnd_r / k #准确率:2/5

r[it] = gnd_r / gnd_all #召回率:2/3

P.append(np.mean(p))

R.append(np.mean(r))

#绘制PR曲线

plt.plot(R, P, linestyle="-", marker='D', label="DSH")

plt.grid(True)

plt.xlim(0, 1)

plt.ylim(0, 1)

plt.xlabel('recall')

plt.ylabel('precision')

plt.legend() # 加图例

plt.show()

return P, R

if __name__=='__main__':

queryBinary = np.array([[1,-1,1,1],[-1,1,-1,-1],[1,-1,-1,-1]])

galleryBinary = np.array([[ 1,-1,-1,-1],

[-1, 1, 1,-1],

[ 1, 1, 1,-1],

[-1,-1, 1, 1],

[ 1, 1,-1,-1],

[ 1, 1, 1,-1],

[-1, 1,-1,-1]])

queryLabel = np.array([[1,0,0],

[1,1,0],

[0,0,1]], dtype=np.int64)

galleryLabel = np.array([[0,1,0],

[1,1,0],

[1,0,1],

[0,0,1],

[0,1,0],

[0,0,1],

[1,1,0]], dtype=np.int64)

P, R = pr_curve(galleryBinary, queryBinary, galleryLabel, queryLabel)

print(f'Precision Recall Curve data:\n"DSH":[{P},{R}],')发布者:全栈程序员栈长,转载请注明出处:https://javaforall.cn/149718.html原文链接:https://javaforall.cn

🌟 相关推荐

女排世界杯:阿根廷3-0击败意大利晋级八强
365bet注册网址

女排世界杯:阿根廷3-0击败意大利晋级八强

📅 10-08 👁️ 6025
电热水壶多久换一次
beat365上不去

电热水壶多久换一次

📅 07-10 👁️ 2187
《英雄联盟》theshy直播位置介绍
365bet真人网

《英雄联盟》theshy直播位置介绍

📅 07-27 👁️ 1647