当前位置:首页> AI教程> 使用LSH算法实现简单的图像搜索功能

使用LSH算法实现简单的图像搜索功能

释放双眼,带上耳机,听听看~!
本文介绍了使用LSH算法在预训练图像分类模型基础上实现简单图像搜索功能的过程,涉及数据准备、模型准备和展示效果。通过该教程,读者可以学习LSH算法在图像搜索中的应用,以及使用预训练模型进行特征提取的方法。

概要

获取相似的图像是当下搜索引擎的一个重要功能,在本次任务中,我使用 LSH 算法,在预训练图像分类模型 Bit 基础上,实现简单的图像搜索功能。实现过程比较简单,容易理解,是个值得上手练习的案例。

数据准备

  • 我们这里用到的数据是 food101 数据集,下载可能会需要很长的时间,并且要保证网络稳定。我们从中选取了三种类型的数据进行本次的任务,包括 32 (dumplings)53 (hamburger)55 (hot_dog)三种类别的图像。
  • 在进行数据处理过程中,主要是将图像调整为大小为 (256, 256) 的尺寸,并且将图像数据归一化到 [0, 1] 范围内。
  • 处理完数据之后进行混洗,选出 80% 作为训练集,20% 作为验证集。
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import time
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tqdm import tqdm
import random

train_ds = tfds.load( "food101", split="train", as_supervised=True)   # 75750
ds = []
for (image, label) in train_ds:
    label = label.numpy()
    if label == 55 or label == 53 or label == 32:
        image = tf.image.resize(image, (256,256))
        image = image / 255.
        image = image.numpy()
        ds.append([image, label])
random.shuffle(ds)
N = int(len(ds)*0.8)
train_images, train_labels= zip(*ds[:N])
val_images, val_labels = zip(*ds[N:]) 

这里是一个工具类,主要是为了随机选取 16 张图像,展示成 4×4 的图像形式。

def show(images):
    if type(images) == tf.Tensor:
        images = images.numpy()
    images = images[:16]
    plt.figure(figsize=(5, 5))
    for i in range(len(images)):
        image = images[i]
        plt.subplot(4, 4, i + 1)
        plt.imshow(image)
        plt.axis('off')
    plt.show()

show(train_images)

效果如下:

使用LSH算法实现简单的图像搜索功能

模型准备

  • 使用hub.KerasLayer函数从指定的 URL 加载预训练模型,该模型是 Bit 模型在 ImageNet 21K 数据集上进行预训练后的结果,专门用于食物分类任务,我这里主要是将其作为特征提取器。
  • 使用tf.keras.Sequential创建一个序列模型,序列模型是一系列网络层的线性堆叠。首先,使用tf.keras.layers.Input定义一个输入层,指定输入图像的形状为 (256, 256, 3) 。接下来,将之前加载的 Bit 模型(bit_model)添加到序列模型中,作为特征提取器。最后,添加一个归一化层tf.keras.layers.Normalization,用于对特征向量进行归一化处理,使其具有均值为 0 ,方差为 1 的标准正态分布。
bit_model  = hub.KerasLayer("https://tfhub.dev/google/experts/bit/r50x1/in21k/food/1")
embedding_model = tf.keras.Sequential(
    [
        tf.keras.layers.Input((256, 256, 3)),
        bit_model,
        tf.keras.layers.Normalization(mean=0, variance=1, name='normalization'),
    ],
    name="embedding_model",
)

工具

bool2int 函数

该函数的作用是将一个布尔数组转换为对应的整数值。函数的输入参数x是一个布尔数组,表示一个哈希码的二进制形式。通过迭代布尔数组中的每个元素,对为True的位置进行位运算,将其对应的位设置为 1 ,并将所有位的值相加,得到最终的整数值。

def bool2int(x):
    y = 0
    for i, j in enumerate(x):
        if j:
            y += 1 << i
    return y

hash_func 函数

  • 该函数将输入的嵌入向量进行哈希处理。函数的输入参数embedding是一个图像经过编码的向量大小是 (1,2048) ,参数random_vectors是随机向量,大小是 (2048,8)
  • 将两个向量进行相乘得到一个结果向量大小是 (1, 8),将其转换为布尔形式。
  • 最后,调用bool2int函数将布尔数组转换为对应的整数哈希码,并将结果以列表的形式返回。

这段代码主要用于将嵌入向量进行哈希处理,将连续的嵌入空间映射到离散的哈希空间中,以便进行快速的相似度搜索或索引。

def hash_func(embedding, random_vectors):
    embedding = np.array(embedding)

    # Random projection.
    bools = np.dot(embedding, random_vectors) > 0
    return [bool2int(bool_vec) for bool_vec in bools]

Table 类

Table 类主要是用于构建哈希表,并实现了添加数据和查询数据的功能。random_vectors 是一个随机生成的大小为 (8,2048) 的向量,用于和图像的特征向量结合生成哈希码。对于每个图像在经过哈希之后,将其添加到 table 字典中,可能同一个哈希值对应多张图像

  • 初始化Table对象时,需要指定哈希表的大小(hash_size)和特征向量的维度(dim)。random_vectors 是一个随机生成的大小为 (dim, hash_size) 的向量,用于和图像的特征向量结合生成哈希码。table 字典用于用于存储数据。

  • add方法用于向哈希表中添加图像。每个图像使用其 id 和其 label 来命名,如 0_55 。通过调用hash_func函数,将图像特征向量和随机向量结合映射为哈希码。在 table 中如果哈希码对应的桶已存在,则将当前条目添加到桶中;否则,创建新的桶,并将条目添加到桶中。

  • query方法用于根据图像特征向量进行查询操作,返回与特征向量相似的条目列表。输入参数为图像经过模型提取的特征向量vectors。通过调用hash_func函数计算出哈希码。遍历哈希码列表,对每个哈希码,在哈希表中查找对应的桶。如果桶存在,则将桶中的条目添加到结果列表中,最后返回结果列表。

class Table:
    def __init__(self, hash_size, dim):
        self.table = {}
        self.hash_size = hash_size
        self.random_vectors = np.random.randn(hash_size, dim).T

    def add(self, id, vectors, label):
        entry = {"id_label": str(id) + "_" + str(label)}
        hashes = hash_func(vectors, self.random_vectors)
        for h in hashes:
            if h in self.table:
                self.table[h].append(entry)
            else:
                self.table[h] = [entry]

    def query(self, vectors):
        hashes = hash_func(vectors, self.random_vectors)
        results = []
        for h in hashes:
            if h in self.table:
                results.extend(self.table[h])
        return results

LSH

这段代码定义了一个LSH 类,它是基于多个哈希表构建的索引结构。

  • 初始化LSH对象时,需要指定哈希表的大小(hash_size)、特征向量的维度(dim)和哈希表的数量(num_tables)。创建num_tablesTable对象,并将其存储在列表tables中。

  • add方法用于向所有哈希表中添加内容。输入参数包括idvectors(图像特征向量)、label(标签)。遍历每个哈希表,调用对应的Table对象的add方法,将该对象经过计算添加到每一个的哈希表中,由于每个哈希表中的 random_vectors 不同,所以计算出来的哈希值也不同。

  • query方法用于根据图像特征向量进行查询操作,返回与特征向量相似的条目列表。输入参数为图像特征向量vectors。遍历每个哈希表,调用每个哈希表的query方法,获取相似的图像列表,追加到最终结果中。通过多个的哈希表索引,可以提高相似度搜索的效率。

class LSH:
    def __init__(self, hash_size, dim, num_tables):
        self.num_tables = num_tables
        self.tables = []
        for i in range(self.num_tables):
            self.tables.append(Table(hash_size, dim))

    def add(self, id, vectors, label):
        for table in self.tables:
            table.add(id, vectors, label)

    def query(self, vectors):
        results = []
        for table in self.tables:
            results.extend(table.query(vectors))
        return results

BuildLSHTable

这段代码定义了一个BuildLSHTable类,用于构建和查询 LSH 哈希表。

  • 初始化BuildLSHTable对象时,需要指定哈希表的大小(hash_size)、特征向量的维度(dim)和哈希表的数量(num_tables)。prediction_model 是一个用于提取特征向量的模型,也就是我们前面定义的embedding_model 。 concrete_function参数用于指定是否使用具体函数来提取特征向量。

  • train方法用于填充 LSH 对象。将训练数据中的每张图像使用prediction_model提取图像的特征向量。然后调用LSH对象的add方法,将特征向量、标签和唯一标识符添加到 LSH 对象中的每一个哈希表中。

  • query方法用于在 LSH 对象,找到与输入图像相似的图像。使用prediction_model提取输入图像的特征向量。调用LSH对象的query方法,获取与特征向量相似的图像。统计相同哈希值对应的相似图片的个数,并对计数结果除 dim 作为该哈希值的相似度。


class BuildLSHTable:
    def __init__( self, prediction_model, concrete_function=False,  hash_size=8, dim=2048, num_tables=10, ):
        self.hash_size = hash_size
        self.dim = dim
        self.num_tables = num_tables
        self.lsh = LSH(self.hash_size, self.dim, self.num_tables)
        self.prediction_model = prediction_model
        self.concrete_function = concrete_function

    def train(self, training_files):
        for id, training_file in enumerate(training_files):
            image, label = training_file
            if len(image.shape) < 4:
                image = image[None, ...]
            if self.concrete_function:
                features = self.prediction_model(tf.constant(image))[ "normalization" ].numpy()
            else:
                features = self.prediction_model.predict(image)
            self.lsh.add(id, features, label)

    def query(self, image, verbose=True):
        if len(image.shape) < 4:
            image = image[None, ...]
        if self.concrete_function:
            features = self.prediction_model(tf.constant(image))[  "normalization" ].numpy()
        else:
            features = self.prediction_model.predict(image)
        results = self.lsh.query(features)
        if verbose:
            print("Matches:", len(results))
        counts = {}
        for r in results:
            if r["id_label"] in counts:
                counts[r["id_label"]] += 1
            else:
                counts[r["id_label"]] = 1
        for k in counts:
            counts[k] = float(counts[k]) / self.dim
        return counts

训练

这里主要是使用训练数据进行训练,也就是哈希表的填充,过程比较简单。

training_files = zip(train_images, train_labels)
lsh_builder = BuildLSHTable(embedding_model)
lsh_builder.train(training_files)

打印:

1/1 [==============================] - 1s 722ms/step
1/1 [==============================] - 0s 20ms/step
1/1 [==============================] - 0s 21ms/step
...
1/1 [==============================] - 0s 21ms/step
1/1 [==============================] - 0s 17ms/step

效果展示

这里主要是随机对 5 张图片进行了相似图片的搜索,我们可以直接看结果图。第一列是输入的 5 张图片,后面几列是根据相似度从高到低展示出来的图片,效果还是可以的,对于汉堡、包子、热狗基本能搜索出来相似的图片。


images = train_images + val_images
labels = train_labels + val_labels

def plot_images(images, labels):
    plt.figure(figsize=(20, 10))
    columns = 5
    for (i, image) in enumerate(images):
        ax = plt.subplot(len(images) // columns + 1, columns, i + 1)
        if i == 0:
            ax.set_title("Query Imagen" + "Label: {}".format(labels[i]))
        else:
            ax.set_title("Similar Image # " + str(i) + "nLabel: {}".format(labels[i]))
        plt.imshow(image.astype("float"))
        plt.axis("off")


def visualize_lsh(lsh_class):
    idx = np.random.choice(len(val_images))
    image = val_images[idx]
    label = val_labels[idx]
    results = lsh_class.query(image)

    candidates = []
    labels = []
    overlaps = []

    for idx, r in enumerate(sorted(results, key=results.get, reverse=True)):
        if idx == 4:
            break
        image_id, label = r.split("_")[0], r.split("_")[1]
        candidates.append(images[int(image_id)])
        labels.append(label)
        overlaps.append(results[r])

    candidates.insert(0, image)
    labels.insert(0, label)

    plot_images(candidates, labels)

for _ in range(5):
    visualize_lsh(lsh_builder)

使用LSH算法实现简单的图像搜索功能

使用LSH算法实现简单的图像搜索功能

使用LSH算法实现简单的图像搜索功能

使用LSH算法实现简单的图像搜索功能

使用LSH算法实现简单的图像搜索功能

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

深度学习模型压缩方法综述

2023-12-17 17:48:14

AI教程

Pytorch实战:波士顿房价预测模型搭建

2023-12-17 18:08:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索