156 lines
4.4 KiB
Python
156 lines
4.4 KiB
Python
import os
from glob import glob
from pprint import pprint

import cv2 as cv
import numpy as np
|
|
|
|
|
|
# the directory of the image database
|
|
# Relative path, so it is resolved against the current working directory
# when the functions below glob it for "*.jpg" files.
database_dir: str = "image.orig"
|
|
|
|
|
|
def countCorrectImage(in_array, lower=100, upper=200):
    """Count result rows whose image id lies in the half-open range [lower, upper).

    Args:
        in_array: iterable of rows whose first element is the numeric image
            id (the rows built by ``retrieval`` have this layout).
        lower: smallest id that counts as correct (inclusive).
        upper: first id that no longer counts as correct (exclusive).

    Returns:
        The number of rows with ``lower <= row[0] < upper``. The defaults
        (100 and 200) keep the id range that was previously hard-coded.
    """
    return sum(1 for item in in_array if lower <= item[0] < upper)
|
|
|
|
|
|
# Compute pixel-by-pixel difference and return the sum
|
|
def compareImgs(img1, img2):
    """Return the summed absolute per-pixel difference between two images.

    ``img2`` is first resized to the width and height of ``img1`` so that
    both arrays can be compared element by element.
    """
    rows, cols = img1.shape[0], img1.shape[1]
    # cv.resize takes the target size as (width, height).
    resized = cv.resize(img2, (cols, rows))
    return cv.absdiff(img1, resized).sum()
|
|
|
|
|
|
def compareImgs_hist(img1, img2):
    """Compare two images by the difference of their 10-bin intensity histograms.

    ``img2`` is resized to the size of ``img1``, a histogram of channel 0 is
    computed for each image with ``cv.calcHist``, and the absolute per-bin
    differences are summed and divided by the number of pixels in ``img1``.

    Args:
        img1: reference image; its width and height define the comparison size.
        img2: image to compare; resized to match ``img1``.

    Returns:
        The summed absolute bin difference divided by ``width * height``
        (0 means the two histograms are identical).
    """
    width, height = img1.shape[1], img1.shape[0]
    img2 = cv.resize(img2, (width, height))
    num_bins = 10

    # The upper bound of a calcHist range is exclusive, so the range must be
    # [0, 256] for pixels of value 255 to be counted; [0, 255] dropped them.
    hist_range = [0, 256]
    hist1 = cv.calcHist([img1], [0], None, [num_bins], hist_range)
    hist2 = cv.calcHist([img2], [0], None, [num_bins], hist_range)

    # Builtin sum adds the per-bin rows one by one starting from 0, exactly
    # like the explicit accumulation loop it replaces.
    total = sum(abs(hist1 - hist2))
    return total / float(width * height)
|
|
|
|
|
|
def histLowerBoundAndUpperBound(retrieval_results, lower_bound, uper_bound):
    """Keep only the rows whose third element lies within the given bounds.

    Both bounds are inclusive. Rows are returned in their original order in
    a new list; the input is not modified.
    """
    kept = []
    for row in retrieval_results:
        if lower_bound <= row[2] <= uper_bound:
            kept.append(row)
    return kept
|
|
|
|
|
|
def retrieval(src_input, database):
    """Score every database image against the query image.

    Both the query and each database image are converted to grayscale and
    compared with ``compareImgs``.

    Args:
        src_input: query image in BGR channel order (as returned by
            ``cv.imread``).
        database: iterable of image file paths. The file name without its
            extension must be an integer, e.g. ``"image.orig/123.jpg"``.

    Returns:
        A list with one ``[image_id, path, diff]`` row per database image,
        in the order the paths were given.

    Raises:
        IOError: if a database image cannot be read.
        ValueError: if a file name (without extension) is not an integer.
    """
    src_gray = cv.cvtColor(src_input, cv.COLOR_BGR2GRAY)

    output = []
    for img in database:
        img_bgr = cv.imread(img)
        # cv.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if img_bgr is None:
            raise IOError("could not read image: " + img)

        img_gray = cv.cvtColor(img_bgr, cv.COLOR_BGR2GRAY)
        diff = compareImgs(src_gray, img_gray)

        # Derive the id from the file name itself so it does not depend on
        # the directory prefix or on the platform's path separator.
        stem = os.path.splitext(os.path.basename(img))[0]
        output.append([int(stem), img, diff])

    return output
|
|
|
|
|
|
def retrievalTest():
    """Run the retrieval demo for "beach.jpg" and report its precision.

    Every ``*.jpg`` in ``database_dir`` is scored against the query, the rows
    whose difference lies inside a fixed window are treated as retrieved, and
    the share of retrieved rows accepted by ``countCorrectImage`` is printed
    as the precision.

    Returns:
        ``[correct_count, retrieved_count, precision_pct]`` where
        ``precision_pct`` is the percentage formatted as a string.

    Raises:
        IOError: if the query image "beach.jpg" cannot be read.
    """
    src_input = cv.imread("beach.jpg")
    # cv.imread returns None on failure instead of raising.
    if src_input is None:
        raise IOError("could not read query image: beach.jpg")

    retrieval_precision_results = retrieval(
        src_input, sorted(glob(database_dir + "/*.jpg"))
    )
    # TODO: recall is not reported yet. The previous second retrieval pass
    # (over "1**.jpg") was removed because its result was never used.

    # Acceptance window for the difference score. The calibration code that
    # used to be commented out here took the min and max difference over
    # rows [100:200] of the sorted results.
    lu_bound_test_result = histLowerBoundAndUpperBound(
        retrieval_precision_results, np.uint64(30241636), np.uint64(62158322)
    )

    # Retrieved images that belong to the correct category.
    precision_correct_count = countCorrectImage(lu_bound_test_result)
    number_of_retrived_image = len(lu_bound_test_result)

    # Nothing retrieved means nothing to divide by; report 0% instead of
    # raising ZeroDivisionError.
    if number_of_retrived_image:
        precision_pct = str((precision_correct_count / number_of_retrived_image) * 100)
    else:
        precision_pct = str(0.0)

    print("correct from retrived image: ".ljust(50) + str(precision_correct_count))
    print("number of retrived image: ".ljust(50) + str(number_of_retrived_image))
    print("precision/correct rate (%, target 60%): ".ljust(50) + precision_pct)

    return [precision_correct_count, number_of_retrived_image, precision_pct]
|
|
|
|
|
|
def test_performance(number):
|
|
# print("1: Image retrieval demo")
|
|
# print("2: SIFT demo")
|
|
|
|
if number == 1:
|
|
metrics = retrievalTest()
|
|
pprint(metrics)
|
|
# sliced_array = retrieval_results[101:200]
|
|
# third_element = [item[2] for item in sliced_array]
|
|
# max_diff = max(third_element)
|
|
# min_diff = min(third_element)
|
|
|
|
# pprint([max_diff, min_diff])
|
|
else:
|
|
print("Invalid input")
|
|
exit()
|
|
|
|
|
|
# Run the retrieval demo only when the file is executed as a script, so that
# importing this module does not read the image database as a side effect.
if __name__ == "__main__":
    test_performance(1)
|