mirror of https://github.com/nomic-ai/gpt4all
LocalDocs version 2 with text embeddings.
parent
d4ce9f4a7c
commit
371e2a5cbc
@ -0,0 +1,190 @@
|
||||
#include "embeddings.h"
|
||||
|
||||
#include <QFile>
|
||||
#include <QFileInfo>
|
||||
#include <QDebug>
|
||||
|
||||
#include "mysettings.h"
|
||||
#include "hnswlib/hnswlib.h"
|
||||
|
||||
#define EMBEDDINGS_VERSION 0
|
||||
|
||||
const int s_dim = 384; // Dimension of the elements
|
||||
const int s_ef_construction = 200; // Controls index search speed/build speed tradeoff
|
||||
const int s_M = 16; // Tightly connected with internal dimensionality of the data
|
||||
// strongly affects the memory consumption
|
||||
|
||||
Embeddings::Embeddings(QObject *parent)
|
||||
: QObject(parent)
|
||||
, m_space(nullptr)
|
||||
, m_hnsw(nullptr)
|
||||
{
|
||||
m_filePath = MySettings::globalInstance()->modelPath()
|
||||
+ QString("embeddings_v%1.dat").arg(EMBEDDINGS_VERSION);
|
||||
}
|
||||
|
||||
Embeddings::~Embeddings()
|
||||
{
|
||||
delete m_hnsw;
|
||||
m_hnsw = nullptr;
|
||||
delete m_space;
|
||||
m_space = nullptr;
|
||||
}
|
||||
|
||||
bool Embeddings::load()
|
||||
{
|
||||
QFileInfo info(m_filePath);
|
||||
if (!info.exists()) {
|
||||
qWarning() << "ERROR: loading embeddings file does not exist" << m_filePath;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!info.isReadable()) {
|
||||
qWarning() << "ERROR: loading embeddings file is not readable" << m_filePath;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!info.isWritable()) {
|
||||
qWarning() << "ERROR: loading embeddings file is not writeable" << m_filePath;
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
m_space = new hnswlib::InnerProductSpace(s_dim);
|
||||
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, m_filePath.toStdString(), s_M, s_ef_construction);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not load hnswlib index:" << e.what();
|
||||
return false;
|
||||
}
|
||||
return isLoaded();
|
||||
}
|
||||
|
||||
bool Embeddings::load(qint64 maxElements)
|
||||
{
|
||||
try {
|
||||
m_space = new hnswlib::InnerProductSpace(s_dim);
|
||||
m_hnsw = new hnswlib::HierarchicalNSW<float>(m_space, maxElements, s_M, s_ef_construction);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not create hnswlib index:" << e.what();
|
||||
return false;
|
||||
}
|
||||
return isLoaded();
|
||||
}
|
||||
|
||||
bool Embeddings::save()
|
||||
{
|
||||
if (!isLoaded())
|
||||
return false;
|
||||
try {
|
||||
m_hnsw->saveIndex(m_filePath.toStdString());
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not save hnswlib index:" << e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Embeddings::isLoaded() const
|
||||
{
|
||||
return m_hnsw != nullptr;
|
||||
}
|
||||
|
||||
bool Embeddings::fileExists() const
|
||||
{
|
||||
QFileInfo info(m_filePath);
|
||||
return info.exists();
|
||||
}
|
||||
|
||||
bool Embeddings::resize(qint64 size)
|
||||
{
|
||||
if (!isLoaded()) {
|
||||
qWarning() << "ERROR: attempting to resize an embedding when the embeddings are not open!";
|
||||
return false;
|
||||
}
|
||||
|
||||
Q_ASSERT(m_hnsw);
|
||||
try {
|
||||
m_hnsw->resizeIndex(size);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not resize hnswlib index:" << e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Embeddings::add(const std::vector<float> &embedding, qint64 label)
|
||||
{
|
||||
if (!isLoaded()) {
|
||||
bool success = load(500);
|
||||
if (!success) {
|
||||
qWarning() << "ERROR: attempting to add an embedding when the embeddings are not open!";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Q_ASSERT(m_hnsw);
|
||||
if (m_hnsw->cur_element_count + 1 > m_hnsw->max_elements_) {
|
||||
if (!resize(m_hnsw->max_elements_ + 500)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
m_hnsw->addPoint(embedding.data(), label, false);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not add embedding to hnswlib index:" << e.what();
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Embeddings::remove(qint64 label)
|
||||
{
|
||||
if (!isLoaded()) {
|
||||
qWarning() << "ERROR: attempting to remove an embedding when the embeddings are not open!";
|
||||
return;
|
||||
}
|
||||
|
||||
Q_ASSERT(m_hnsw);
|
||||
try {
|
||||
m_hnsw->markDelete(label);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not add remove embedding from hnswlib index:" << e.what();
|
||||
}
|
||||
}
|
||||
|
||||
void Embeddings::clear()
|
||||
{
|
||||
delete m_hnsw;
|
||||
m_hnsw = nullptr;
|
||||
delete m_space;
|
||||
m_space = nullptr;
|
||||
}
|
||||
|
||||
std::vector<qint64> Embeddings::search(const std::vector<float> &embedding, int K)
|
||||
{
|
||||
if (!isLoaded())
|
||||
return std::vector<qint64>();
|
||||
|
||||
Q_ASSERT(m_hnsw);
|
||||
std::priority_queue<std::pair<float, hnswlib::labeltype>> result;
|
||||
try {
|
||||
result = m_hnsw->searchKnn(embedding.data(), K);
|
||||
} catch (const std::exception &e) {
|
||||
qWarning() << "ERROR: could not search hnswlib index:" << e.what();
|
||||
return std::vector<qint64>();
|
||||
}
|
||||
|
||||
std::vector<qint64> neighbors;
|
||||
neighbors.reserve(K);
|
||||
|
||||
while(!result.empty()) {
|
||||
neighbors.push_back(result.top().second);
|
||||
result.pop();
|
||||
}
|
||||
|
||||
// Reverse the neighbors, as the top of the priority queue is the farthest neighbor.
|
||||
std::reverse(neighbors.begin(), neighbors.end());
|
||||
|
||||
return neighbors;
|
||||
}
|
@ -0,0 +1,45 @@
|
||||
#ifndef EMBEDDINGS_H
|
||||
#define EMBEDDINGS_H
|
||||
|
||||
#include <QObject>
|
||||
|
||||
namespace hnswlib {
|
||||
template <typename T>
|
||||
class HierarchicalNSW;
|
||||
class InnerProductSpace;
|
||||
}
|
||||
|
||||
class Embeddings : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
public:
|
||||
Embeddings(QObject *parent);
|
||||
virtual ~Embeddings();
|
||||
|
||||
bool load();
|
||||
bool load(qint64 maxElements);
|
||||
bool save();
|
||||
bool isLoaded() const;
|
||||
bool fileExists() const;
|
||||
bool resize(qint64 size);
|
||||
|
||||
// Adds the embedding and returns the label used
|
||||
bool add(const std::vector<float> &embedding, qint64 label);
|
||||
|
||||
// Removes the embedding at label by marking it as unused
|
||||
void remove(qint64 label);
|
||||
|
||||
// Clears the embeddings
|
||||
void clear();
|
||||
|
||||
// Performs a nearest neighbor search of the embeddings and returns a vector of labels
|
||||
// for the K nearest neighbors of the given embedding
|
||||
std::vector<qint64> search(const std::vector<float> &embedding, int K);
|
||||
|
||||
private:
|
||||
QString m_filePath;
|
||||
hnswlib::InnerProductSpace *m_space;
|
||||
hnswlib::HierarchicalNSW<float> *m_hnsw;
|
||||
};
|
||||
|
||||
#endif // EMBEDDINGS_H
|
@ -0,0 +1,64 @@
|
||||
#include "embllm.h"
|
||||
#include "modellist.h"
|
||||
|
||||
EmbeddingLLM::EmbeddingLLM()
|
||||
: QObject{nullptr}
|
||||
, m_model{nullptr}
|
||||
{
|
||||
}
|
||||
|
||||
EmbeddingLLM::~EmbeddingLLM()
|
||||
{
|
||||
delete m_model;
|
||||
m_model = nullptr;
|
||||
}
|
||||
|
||||
bool EmbeddingLLM::loadModel()
|
||||
{
|
||||
const EmbeddingModels *embeddingModels = ModelList::globalInstance()->embeddingModels();
|
||||
if (!embeddingModels->count())
|
||||
return false;
|
||||
|
||||
const ModelInfo defaultModel = embeddingModels->defaultModelInfo();
|
||||
|
||||
QString filePath = defaultModel.dirpath + defaultModel.filename();
|
||||
QFileInfo fileInfo(filePath);
|
||||
if (!fileInfo.exists()) {
|
||||
qWarning() << "WARNING: Could not load sbert because file does not exist";
|
||||
m_model = nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
m_model = LLModel::Implementation::construct(filePath.toStdString(), "auto");
|
||||
bool success = m_model->loadModel(filePath.toStdString());
|
||||
if (!success) {
|
||||
qWarning() << "WARNING: Could not load sbert";
|
||||
delete m_model;
|
||||
m_model = nullptr;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (m_model->implementation().modelType()[0] != 'B') {
|
||||
qWarning() << "WARNING: Model type is not sbert";
|
||||
delete m_model;
|
||||
m_model = nullptr;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EmbeddingLLM::hasModel() const
|
||||
{
|
||||
return m_model;
|
||||
}
|
||||
|
||||
std::vector<float> EmbeddingLLM::generateEmbeddings(const QString &text)
|
||||
{
|
||||
if (!hasModel() && !loadModel()) {
|
||||
qWarning() << "WARNING: Could not load sbert model for embeddings";
|
||||
return std::vector<float>();
|
||||
}
|
||||
|
||||
Q_ASSERT(hasModel());
|
||||
return m_model->embedding(text.toStdString());
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
#ifndef EMBLLM_H
|
||||
#define EMBLLM_H
|
||||
|
||||
#include <QObject>
|
||||
#include <QThread>
|
||||
#include "../gpt4all-backend/llmodel.h"
|
||||
|
||||
class EmbeddingLLM : public QObject
|
||||
{
|
||||
Q_OBJECT
|
||||
public:
|
||||
EmbeddingLLM();
|
||||
virtual ~EmbeddingLLM();
|
||||
|
||||
bool hasModel() const;
|
||||
|
||||
public Q_SLOTS:
|
||||
std::vector<float> generateEmbeddings(const QString &text);
|
||||
|
||||
private:
|
||||
bool loadModel();
|
||||
|
||||
private:
|
||||
LLModel *m_model = nullptr;
|
||||
};
|
||||
|
||||
#endif // EMBLLM_H
|
@ -0,0 +1,167 @@
|
||||
#pragma once
|
||||
#include <unordered_map>
|
||||
#include <fstream>
|
||||
#include <mutex>
|
||||
#include <algorithm>
|
||||
#include <assert.h>
|
||||
|
||||
namespace hnswlib {
|
||||
template<typename dist_t>
|
||||
class BruteforceSearch : public AlgorithmInterface<dist_t> {
|
||||
public:
|
||||
char *data_;
|
||||
size_t maxelements_;
|
||||
size_t cur_element_count;
|
||||
size_t size_per_element_;
|
||||
|
||||
size_t data_size_;
|
||||
DISTFUNC <dist_t> fstdistfunc_;
|
||||
void *dist_func_param_;
|
||||
std::mutex index_lock;
|
||||
|
||||
std::unordered_map<labeltype, size_t > dict_external_to_internal;
|
||||
|
||||
|
||||
BruteforceSearch(SpaceInterface <dist_t> *s)
|
||||
: data_(nullptr),
|
||||
maxelements_(0),
|
||||
cur_element_count(0),
|
||||
size_per_element_(0),
|
||||
data_size_(0),
|
||||
dist_func_param_(nullptr) {
|
||||
}
|
||||
|
||||
|
||||
BruteforceSearch(SpaceInterface<dist_t> *s, const std::string &location)
|
||||
: data_(nullptr),
|
||||
maxelements_(0),
|
||||
cur_element_count(0),
|
||||
size_per_element_(0),
|
||||
data_size_(0),
|
||||
dist_func_param_(nullptr) {
|
||||
loadIndex(location, s);
|
||||
}
|
||||
|
||||
|
||||
BruteforceSearch(SpaceInterface <dist_t> *s, size_t maxElements) {
|
||||
maxelements_ = maxElements;
|
||||
data_size_ = s->get_data_size();
|
||||
fstdistfunc_ = s->get_dist_func();
|
||||
dist_func_param_ = s->get_dist_func_param();
|
||||
size_per_element_ = data_size_ + sizeof(labeltype);
|
||||
data_ = (char *) malloc(maxElements * size_per_element_);
|
||||
if (data_ == nullptr)
|
||||
throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data");
|
||||
cur_element_count = 0;
|
||||
}
|
||||
|
||||
|
||||
~BruteforceSearch() {
|
||||
free(data_);
|
||||
}
|
||||
|
||||
|
||||
void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) {
|
||||
int idx;
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(index_lock);
|
||||
|
||||
auto search = dict_external_to_internal.find(label);
|
||||
if (search != dict_external_to_internal.end()) {
|
||||
idx = search->second;
|
||||
} else {
|
||||
if (cur_element_count >= maxelements_) {
|
||||
throw std::runtime_error("The number of elements exceeds the specified limit\n");
|
||||
}
|
||||
idx = cur_element_count;
|
||||
dict_external_to_internal[label] = idx;
|
||||
cur_element_count++;
|
||||
}
|
||||
}
|
||||
memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype));
|
||||
memcpy(data_ + size_per_element_ * idx, datapoint, data_size_);
|
||||
}
|
||||
|
||||
|
||||
void removePoint(labeltype cur_external) {
|
||||
size_t cur_c = dict_external_to_internal[cur_external];
|
||||
|
||||
dict_external_to_internal.erase(cur_external);
|
||||
|
||||
labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_));
|
||||
dict_external_to_internal[label] = cur_c;
|
||||
memcpy(data_ + size_per_element_ * cur_c,
|
||||
data_ + size_per_element_ * (cur_element_count-1),
|
||||
data_size_+sizeof(labeltype));
|
||||
cur_element_count--;
|
||||
}
|
||||
|
||||
|
||||
std::priority_queue<std::pair<dist_t, labeltype >>
|
||||
searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const {
|
||||
assert(k <= cur_element_count);
|
||||
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
|
||||
if (cur_element_count == 0) return topResults;
|
||||
for (int i = 0; i < k; i++) {
|
||||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
||||
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
|
||||
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
||||
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
||||
}
|
||||
}
|
||||
dist_t lastdist = topResults.empty() ? std::numeric_limits<dist_t>::max() : topResults.top().first;
|
||||
for (int i = k; i < cur_element_count; i++) {
|
||||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
|
||||
if (dist <= lastdist) {
|
||||
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
|
||||
if ((!isIdAllowed) || (*isIdAllowed)(label)) {
|
||||
topResults.push(std::pair<dist_t, labeltype>(dist, label));
|
||||
}
|
||||
if (topResults.size() > k)
|
||||
topResults.pop();
|
||||
|
||||
if (!topResults.empty()) {
|
||||
lastdist = topResults.top().first;
|
||||
}
|
||||
}
|
||||
}
|
||||
return topResults;
|
||||
}
|
||||
|
||||
|
||||
void saveIndex(const std::string &location) {
|
||||
std::ofstream output(location, std::ios::binary);
|
||||
std::streampos position;
|
||||
|
||||
writeBinaryPOD(output, maxelements_);
|
||||
writeBinaryPOD(output, size_per_element_);
|
||||
writeBinaryPOD(output, cur_element_count);
|
||||
|
||||
output.write(data_, maxelements_ * size_per_element_);
|
||||
|
||||
output.close();
|
||||
}
|
||||
|
||||
|
||||
void loadIndex(const std::string &location, SpaceInterface<dist_t> *s) {
|
||||
std::ifstream input(location, std::ios::binary);
|
||||
std::streampos position;
|
||||
|
||||
readBinaryPOD(input, maxelements_);
|
||||
readBinaryPOD(input, size_per_element_);
|
||||
readBinaryPOD(input, cur_element_count);
|
||||
|
||||
data_size_ = s->get_data_size();
|
||||
fstdistfunc_ = s->get_dist_func();
|
||||
dist_func_param_ = s->get_dist_func_param();
|
||||
size_per_element_ = data_size_ + sizeof(labeltype);
|
||||
data_ = (char *) malloc(maxelements_ * size_per_element_);
|
||||
if (data_ == nullptr)
|
||||
throw std::runtime_error("Not enough memory: loadIndex failed to allocate data");
|
||||
|
||||
input.read(data_, maxelements_ * size_per_element_);
|
||||
|
||||
input.close();
|
||||
}
|
||||
};
|
||||
} // namespace hnswlib
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,199 @@
|
||||
#pragma once
|
||||
#ifndef NO_MANUAL_VECTORIZATION
|
||||
#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64))
|
||||
#define USE_SSE
|
||||
#ifdef __AVX__
|
||||
#define USE_AVX
|
||||
#ifdef __AVX512F__
|
||||
#define USE_AVX512
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX) || defined(USE_SSE)
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#include <stdexcept>
|
||||
void cpuid(int32_t out[4], int32_t eax, int32_t ecx) {
|
||||
__cpuidex(out, eax, ecx);
|
||||
}
|
||||
static __int64 xgetbv(unsigned int x) {
|
||||
return _xgetbv(x);
|
||||
}
|
||||
#else
|
||||
#include <x86intrin.h>
|
||||
#include <cpuid.h>
|
||||
#include <stdint.h>
|
||||
static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) {
|
||||
__cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]);
|
||||
}
|
||||
static uint64_t xgetbv(unsigned int index) {
|
||||
uint32_t eax, edx;
|
||||
__asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index));
|
||||
return ((uint64_t)edx << 32) | eax;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX512)
|
||||
#include <immintrin.h>
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#define PORTABLE_ALIGN32 __attribute__((aligned(32)))
|
||||
#define PORTABLE_ALIGN64 __attribute__((aligned(64)))
|
||||
#else
|
||||
#define PORTABLE_ALIGN32 __declspec(align(32))
|
||||
#define PORTABLE_ALIGN64 __declspec(align(64))
|
||||
#endif
|
||||
|
||||
// Adapted from https://github.com/Mysticial/FeatureDetector
|
||||
#define _XCR_XFEATURE_ENABLED_MASK 0
|
||||
|
||||
static bool AVXCapable() {
|
||||
int cpuInfo[4];
|
||||
|
||||
// CPU support
|
||||
cpuid(cpuInfo, 0, 0);
|
||||
int nIds = cpuInfo[0];
|
||||
|
||||
bool HW_AVX = false;
|
||||
if (nIds >= 0x00000001) {
|
||||
cpuid(cpuInfo, 0x00000001, 0);
|
||||
HW_AVX = (cpuInfo[2] & ((int)1 << 28)) != 0;
|
||||
}
|
||||
|
||||
// OS support
|
||||
cpuid(cpuInfo, 1, 0);
|
||||
|
||||
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
|
||||
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
|
||||
|
||||
bool avxSupported = false;
|
||||
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
|
||||
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
|
||||
avxSupported = (xcrFeatureMask & 0x6) == 0x6;
|
||||
}
|
||||
return HW_AVX && avxSupported;
|
||||
}
|
||||
|
||||
static bool AVX512Capable() {
|
||||
if (!AVXCapable()) return false;
|
||||
|
||||
int cpuInfo[4];
|
||||
|
||||
// CPU support
|
||||
cpuid(cpuInfo, 0, 0);
|
||||
int nIds = cpuInfo[0];
|
||||
|
||||
bool HW_AVX512F = false;
|
||||
if (nIds >= 0x00000007) { // AVX512 Foundation
|
||||
cpuid(cpuInfo, 0x00000007, 0);
|
||||
HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0;
|
||||
}
|
||||
|
||||
// OS support
|
||||
cpuid(cpuInfo, 1, 0);
|
||||
|
||||
bool osUsesXSAVE_XRSTORE = (cpuInfo[2] & (1 << 27)) != 0;
|
||||
bool cpuAVXSuport = (cpuInfo[2] & (1 << 28)) != 0;
|
||||
|
||||
bool avx512Supported = false;
|
||||
if (osUsesXSAVE_XRSTORE && cpuAVXSuport) {
|
||||
uint64_t xcrFeatureMask = xgetbv(_XCR_XFEATURE_ENABLED_MASK);
|
||||
avx512Supported = (xcrFeatureMask & 0xe6) == 0xe6;
|
||||
}
|
||||
return HW_AVX512F && avx512Supported;
|
||||
}
|
||||
#endif
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <string.h>
|
||||
|
||||
namespace hnswlib {
|
||||
typedef size_t labeltype;
|
||||
|
||||
// This can be extended to store state for filtering (e.g. from a std::set)
|
||||
class BaseFilterFunctor {
|
||||
public:
|
||||
virtual bool operator()(hnswlib::labeltype id) { return true; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class pairGreater {
|
||||
public:
|
||||
bool operator()(const T& p1, const T& p2) {
|
||||
return p1.first > p2.first;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T>
|
||||
static void writeBinaryPOD(std::ostream &out, const T &podRef) {
|
||||
out.write((char *) &podRef, sizeof(T));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void readBinaryPOD(std::istream &in, T &podRef) {
|
||||
in.read((char *) &podRef, sizeof(T));
|
||||
}
|
||||
|
||||
template<typename MTYPE>
|
||||
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);
|
||||
|
||||
template<typename MTYPE>
|
||||
class SpaceInterface {
|
||||
public:
|
||||
// virtual void search(void *);
|
||||
virtual size_t get_data_size() = 0;
|
||||
|
||||
virtual DISTFUNC<MTYPE> get_dist_func() = 0;
|
||||
|
||||
virtual void *get_dist_func_param() = 0;
|
||||
|
||||
virtual ~SpaceInterface() {}
|
||||
};
|
||||
|
||||
template<typename dist_t>
|
||||
class AlgorithmInterface {
|
||||
public:
|
||||
virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0;
|
||||
|
||||
virtual std::priority_queue<std::pair<dist_t, labeltype>>
|
||||
searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0;
|
||||
|
||||
// Return k nearest neighbor in the order of closer fist
|
||||
virtual std::vector<std::pair<dist_t, labeltype>>
|
||||
searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const;
|
||||
|
||||
virtual void saveIndex(const std::string &location) = 0;
|
||||
virtual ~AlgorithmInterface(){
|
||||
}
|
||||
};
|
||||
|
||||
template<typename dist_t>
|
||||
std::vector<std::pair<dist_t, labeltype>>
|
||||
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k,
|
||||
BaseFilterFunctor* isIdAllowed) const {
|
||||
std::vector<std::pair<dist_t, labeltype>> result;
|
||||
|
||||
// here searchKnn returns the result in the order of further first
|
||||
auto ret = searchKnn(query_data, k, isIdAllowed);
|
||||
{
|
||||
size_t sz = ret.size();
|
||||
result.resize(sz);
|
||||
while (!ret.empty()) {
|
||||
result[--sz] = ret.top();
|
||||
ret.pop();
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
} // namespace hnswlib
|
||||
|
||||
#include "space_l2.h"
|
||||
#include "space_ip.h"
|
||||
#include "bruteforce.h"
|
||||
#include "hnswalg.h"
|
@ -0,0 +1,375 @@
|
||||
#pragma once
|
||||
#include "hnswlib.h"
|
||||
|
||||
namespace hnswlib {
|
||||
|
||||
static float
|
||||
InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float res = 0;
|
||||
for (unsigned i = 0; i < qty; i++) {
|
||||
res += ((float *) pVect1)[i] * ((float *) pVect2)[i];
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) {
|
||||
return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr);
|
||||
}
|
||||
|
||||
#if defined(USE_AVX)
|
||||
|
||||
// Favor using AVX if available.
|
||||
static float
|
||||
InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
size_t qty4 = qty / 4;
|
||||
|
||||
const float *pEnd1 = pVect1 + 16 * qty16;
|
||||
const float *pEnd2 = pVect1 + 4 * qty4;
|
||||
|
||||
__m256 sum256 = _mm256_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
||||
|
||||
__m256 v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
__m256 v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
__m128 v1, v2;
|
||||
__m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1));
|
||||
|
||||
while (pVect1 < pEnd2) {
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
_mm_store_ps(TmpRes, sum_prod);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
return sum;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(USE_SSE)
|
||||
|
||||
static float
|
||||
InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
size_t qty4 = qty / 4;
|
||||
|
||||
const float *pEnd1 = pVect1 + 16 * qty16;
|
||||
const float *pEnd2 = pVect1 + 4 * qty4;
|
||||
|
||||
__m128 v1, v2;
|
||||
__m128 sum_prod = _mm_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
while (pVect1 < pEnd2) {
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
_mm_store_ps(TmpRes, sum_prod);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(USE_AVX512)
|
||||
|
||||
static float
|
||||
InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN64 TmpRes[16];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
|
||||
|
||||
const float *pEnd1 = pVect1 + 16 * qty16;
|
||||
|
||||
__m512 sum512 = _mm512_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
||||
|
||||
__m512 v1 = _mm512_loadu_ps(pVect1);
|
||||
pVect1 += 16;
|
||||
__m512 v2 = _mm512_loadu_ps(pVect2);
|
||||
pVect2 += 16;
|
||||
sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
_mm512_store_ps(TmpRes, sum512);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15];
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
|
||||
static float
|
||||
InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
|
||||
|
||||
const float *pEnd1 = pVect1 + 16 * qty16;
|
||||
|
||||
__m256 sum256 = _mm256_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
||||
|
||||
__m256 v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
__m256 v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2));
|
||||
}
|
||||
|
||||
_mm256_store_ps(TmpRes, sum256);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(USE_SSE)
|
||||
|
||||
static float
|
||||
InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
size_t qty16 = qty / 16;
|
||||
|
||||
const float *pEnd1 = pVect1 + 16 * qty16;
|
||||
|
||||
__m128 v1, v2;
|
||||
__m128 sum_prod = _mm_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2));
|
||||
}
|
||||
_mm_store_ps(TmpRes, sum_prod);
|
||||
float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
|
||||
return sum;
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
|
||||
static DISTFUNC<float> InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE;
|
||||
static DISTFUNC<float> InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE;
|
||||
static DISTFUNC<float> InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE;
|
||||
static DISTFUNC<float> InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE;
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
size_t qty16 = qty >> 4 << 4;
|
||||
float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16);
|
||||
float *pVect1 = (float *) pVect1v + qty16;
|
||||
float *pVect2 = (float *) pVect2v + qty16;
|
||||
|
||||
size_t qty_left = qty - qty16;
|
||||
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
|
||||
return 1.0f - (res + res_tail);
|
||||
}
|
||||
|
||||
static float
|
||||
InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
size_t qty4 = qty >> 2 << 2;
|
||||
|
||||
float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4);
|
||||
size_t qty_left = qty - qty4;
|
||||
|
||||
float *pVect1 = (float *) pVect1v + qty4;
|
||||
float *pVect2 = (float *) pVect2v + qty4;
|
||||
float res_tail = InnerProduct(pVect1, pVect2, &qty_left);
|
||||
|
||||
return 1.0f - (res + res_tail);
|
||||
}
|
||||
#endif
|
||||
|
||||
class InnerProductSpace : public SpaceInterface<float> {
|
||||
DISTFUNC<float> fstdistfunc_;
|
||||
size_t data_size_;
|
||||
size_t dim_;
|
||||
|
||||
public:
|
||||
InnerProductSpace(size_t dim) {
|
||||
fstdistfunc_ = InnerProductDistance;
|
||||
#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512)
|
||||
#if defined(USE_AVX512)
|
||||
if (AVX512Capable()) {
|
||||
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512;
|
||||
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512;
|
||||
} else if (AVXCapable()) {
|
||||
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
||||
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
||||
}
|
||||
#elif defined(USE_AVX)
|
||||
if (AVXCapable()) {
|
||||
InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX;
|
||||
InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX;
|
||||
}
|
||||
#endif
|
||||
#if defined(USE_AVX)
|
||||
if (AVXCapable()) {
|
||||
InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX;
|
||||
InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (dim % 16 == 0)
|
||||
fstdistfunc_ = InnerProductDistanceSIMD16Ext;
|
||||
else if (dim % 4 == 0)
|
||||
fstdistfunc_ = InnerProductDistanceSIMD4Ext;
|
||||
else if (dim > 16)
|
||||
fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals;
|
||||
else if (dim > 4)
|
||||
fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals;
|
||||
#endif
|
||||
dim_ = dim;
|
||||
data_size_ = dim * sizeof(float);
|
||||
}
|
||||
|
||||
size_t get_data_size() {
|
||||
return data_size_;
|
||||
}
|
||||
|
||||
DISTFUNC<float> get_dist_func() {
|
||||
return fstdistfunc_;
|
||||
}
|
||||
|
||||
void *get_dist_func_param() {
|
||||
return &dim_;
|
||||
}
|
||||
|
||||
~InnerProductSpace() {}
|
||||
};
|
||||
|
||||
} // namespace hnswlib
|
@ -0,0 +1,324 @@
|
||||
#pragma once
|
||||
#include "hnswlib.h"
|
||||
|
||||
namespace hnswlib {
|
||||
|
||||
static float
|
||||
L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
float res = 0;
|
||||
for (size_t i = 0; i < qty; i++) {
|
||||
float t = *pVect1 - *pVect2;
|
||||
pVect1++;
|
||||
pVect2++;
|
||||
res += t * t;
|
||||
}
|
||||
return (res);
|
||||
}
|
||||
|
||||
#if defined(USE_AVX512)
|
||||
|
||||
// Favor using AVX512 if available.
|
||||
static float
|
||||
L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float PORTABLE_ALIGN64 TmpRes[16];
|
||||
size_t qty16 = qty >> 4;
|
||||
|
||||
const float *pEnd1 = pVect1 + (qty16 << 4);
|
||||
|
||||
__m512 diff, v1, v2;
|
||||
__m512 sum = _mm512_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm512_loadu_ps(pVect1);
|
||||
pVect1 += 16;
|
||||
v2 = _mm512_loadu_ps(pVect2);
|
||||
pVect2 += 16;
|
||||
diff = _mm512_sub_ps(v1, v2);
|
||||
// sum = _mm512_fmadd_ps(diff, diff, sum);
|
||||
sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff));
|
||||
}
|
||||
|
||||
_mm512_store_ps(TmpRes, sum);
|
||||
float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] +
|
||||
TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] +
|
||||
TmpRes[13] + TmpRes[14] + TmpRes[15];
|
||||
|
||||
return (res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_AVX)
|
||||
|
||||
// Favor using AVX if available.
|
||||
static float
|
||||
L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
size_t qty16 = qty >> 4;
|
||||
|
||||
const float *pEnd1 = pVect1 + (qty16 << 4);
|
||||
|
||||
__m256 diff, v1, v2;
|
||||
__m256 sum = _mm256_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm256_loadu_ps(pVect1);
|
||||
pVect1 += 8;
|
||||
v2 = _mm256_loadu_ps(pVect2);
|
||||
pVect2 += 8;
|
||||
diff = _mm256_sub_ps(v1, v2);
|
||||
sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
|
||||
}
|
||||
|
||||
_mm256_store_ps(TmpRes, sum);
|
||||
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(USE_SSE)
|
||||
|
||||
static float
|
||||
L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
size_t qty16 = qty >> 4;
|
||||
|
||||
const float *pEnd1 = pVect1 + (qty16 << 4);
|
||||
|
||||
__m128 diff, v1, v2;
|
||||
__m128 sum = _mm_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
//_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0);
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
}
|
||||
|
||||
_mm_store_ps(TmpRes, sum);
|
||||
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
|
||||
static DISTFUNC<float> L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE;
|
||||
|
||||
static float
|
||||
L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
size_t qty16 = qty >> 4 << 4;
|
||||
float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16);
|
||||
float *pVect1 = (float *) pVect1v + qty16;
|
||||
float *pVect2 = (float *) pVect2v + qty16;
|
||||
|
||||
size_t qty_left = qty - qty16;
|
||||
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
|
||||
return (res + res_tail);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
#if defined(USE_SSE)
|
||||
static float
|
||||
L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
float PORTABLE_ALIGN32 TmpRes[8];
|
||||
float *pVect1 = (float *) pVect1v;
|
||||
float *pVect2 = (float *) pVect2v;
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
|
||||
|
||||
size_t qty4 = qty >> 2;
|
||||
|
||||
const float *pEnd1 = pVect1 + (qty4 << 2);
|
||||
|
||||
__m128 diff, v1, v2;
|
||||
__m128 sum = _mm_set1_ps(0);
|
||||
|
||||
while (pVect1 < pEnd1) {
|
||||
v1 = _mm_loadu_ps(pVect1);
|
||||
pVect1 += 4;
|
||||
v2 = _mm_loadu_ps(pVect2);
|
||||
pVect2 += 4;
|
||||
diff = _mm_sub_ps(v1, v2);
|
||||
sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
|
||||
}
|
||||
_mm_store_ps(TmpRes, sum);
|
||||
return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
|
||||
}
|
||||
|
||||
static float
|
||||
L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
size_t qty4 = qty >> 2 << 2;
|
||||
|
||||
float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4);
|
||||
size_t qty_left = qty - qty4;
|
||||
|
||||
float *pVect1 = (float *) pVect1v + qty4;
|
||||
float *pVect2 = (float *) pVect2v + qty4;
|
||||
float res_tail = L2Sqr(pVect1, pVect2, &qty_left);
|
||||
|
||||
return (res + res_tail);
|
||||
}
|
||||
#endif
|
||||
|
||||
class L2Space : public SpaceInterface<float> {
|
||||
DISTFUNC<float> fstdistfunc_;
|
||||
size_t data_size_;
|
||||
size_t dim_;
|
||||
|
||||
public:
|
||||
L2Space(size_t dim) {
|
||||
fstdistfunc_ = L2Sqr;
|
||||
#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512)
|
||||
#if defined(USE_AVX512)
|
||||
if (AVX512Capable())
|
||||
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512;
|
||||
else if (AVXCapable())
|
||||
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
||||
#elif defined(USE_AVX)
|
||||
if (AVXCapable())
|
||||
L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX;
|
||||
#endif
|
||||
|
||||
if (dim % 16 == 0)
|
||||
fstdistfunc_ = L2SqrSIMD16Ext;
|
||||
else if (dim % 4 == 0)
|
||||
fstdistfunc_ = L2SqrSIMD4Ext;
|
||||
else if (dim > 16)
|
||||
fstdistfunc_ = L2SqrSIMD16ExtResiduals;
|
||||
else if (dim > 4)
|
||||
fstdistfunc_ = L2SqrSIMD4ExtResiduals;
|
||||
#endif
|
||||
dim_ = dim;
|
||||
data_size_ = dim * sizeof(float);
|
||||
}
|
||||
|
||||
size_t get_data_size() {
|
||||
return data_size_;
|
||||
}
|
||||
|
||||
DISTFUNC<float> get_dist_func() {
|
||||
return fstdistfunc_;
|
||||
}
|
||||
|
||||
void *get_dist_func_param() {
|
||||
return &dim_;
|
||||
}
|
||||
|
||||
~L2Space() {}
|
||||
};
|
||||
|
||||
static int
|
||||
L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) {
|
||||
size_t qty = *((size_t *) qty_ptr);
|
||||
int res = 0;
|
||||
unsigned char *a = (unsigned char *) pVect1;
|
||||
unsigned char *b = (unsigned char *) pVect2;
|
||||
|
||||
qty = qty >> 2;
|
||||
for (size_t i = 0; i < qty; i++) {
|
||||
res += ((*a) - (*b)) * ((*a) - (*b));
|
||||
a++;
|
||||
b++;
|
||||
res += ((*a) - (*b)) * ((*a) - (*b));
|
||||
a++;
|
||||
b++;
|
||||
res += ((*a) - (*b)) * ((*a) - (*b));
|
||||
a++;
|
||||
b++;
|
||||
res += ((*a) - (*b)) * ((*a) - (*b));
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
return (res);
|
||||
}
|
||||
|
||||
static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) {
|
||||
size_t qty = *((size_t*)qty_ptr);
|
||||
int res = 0;
|
||||
unsigned char* a = (unsigned char*)pVect1;
|
||||
unsigned char* b = (unsigned char*)pVect2;
|
||||
|
||||
for (size_t i = 0; i < qty; i++) {
|
||||
res += ((*a) - (*b)) * ((*a) - (*b));
|
||||
a++;
|
||||
b++;
|
||||
}
|
||||
return (res);
|
||||
}
|
||||
|
||||
class L2SpaceI : public SpaceInterface<int> {
|
||||
DISTFUNC<int> fstdistfunc_;
|
||||
size_t data_size_;
|
||||
size_t dim_;
|
||||
|
||||
public:
|
||||
L2SpaceI(size_t dim) {
|
||||
if (dim % 4 == 0) {
|
||||
fstdistfunc_ = L2SqrI4x;
|
||||
} else {
|
||||
fstdistfunc_ = L2SqrI;
|
||||
}
|
||||
dim_ = dim;
|
||||
data_size_ = dim * sizeof(unsigned char);
|
||||
}
|
||||
|
||||
size_t get_data_size() {
|
||||
return data_size_;
|
||||
}
|
||||
|
||||
DISTFUNC<int> get_dist_func() {
|
||||
return fstdistfunc_;
|
||||
}
|
||||
|
||||
void *get_dist_func_param() {
|
||||
return &dim_;
|
||||
}
|
||||
|
||||
~L2SpaceI() {}
|
||||
};
|
||||
} // namespace hnswlib
|
@ -0,0 +1,78 @@
|
||||
#pragma once
|
||||
|
||||
#include <mutex>
|
||||
#include <string.h>
|
||||
#include <deque>
|
||||
|
||||
namespace hnswlib {
|
||||
typedef unsigned short int vl_type;
|
||||
|
||||
class VisitedList {
|
||||
public:
|
||||
vl_type curV;
|
||||
vl_type *mass;
|
||||
unsigned int numelements;
|
||||
|
||||
VisitedList(int numelements1) {
|
||||
curV = -1;
|
||||
numelements = numelements1;
|
||||
mass = new vl_type[numelements];
|
||||
}
|
||||
|
||||
void reset() {
|
||||
curV++;
|
||||
if (curV == 0) {
|
||||
memset(mass, 0, sizeof(vl_type) * numelements);
|
||||
curV++;
|
||||
}
|
||||
}
|
||||
|
||||
~VisitedList() { delete[] mass; }
|
||||
};
|
||||
///////////////////////////////////////////////////////////
|
||||
//
|
||||
// Class for multi-threaded pool-management of VisitedLists
|
||||
//
|
||||
/////////////////////////////////////////////////////////
|
||||
|
||||
class VisitedListPool {
|
||||
std::deque<VisitedList *> pool;
|
||||
std::mutex poolguard;
|
||||
int numelements;
|
||||
|
||||
public:
|
||||
VisitedListPool(int initmaxpools, int numelements1) {
|
||||
numelements = numelements1;
|
||||
for (int i = 0; i < initmaxpools; i++)
|
||||
pool.push_front(new VisitedList(numelements));
|
||||
}
|
||||
|
||||
VisitedList *getFreeVisitedList() {
|
||||
VisitedList *rez;
|
||||
{
|
||||
std::unique_lock <std::mutex> lock(poolguard);
|
||||
if (pool.size() > 0) {
|
||||
rez = pool.front();
|
||||
pool.pop_front();
|
||||
} else {
|
||||
rez = new VisitedList(numelements);
|
||||
}
|
||||
}
|
||||
rez->reset();
|
||||
return rez;
|
||||
}
|
||||
|
||||
void releaseVisitedList(VisitedList *vl) {
|
||||
std::unique_lock <std::mutex> lock(poolguard);
|
||||
pool.push_front(vl);
|
||||
}
|
||||
|
||||
~VisitedListPool() {
|
||||
while (pool.size()) {
|
||||
VisitedList *rez = pool.front();
|
||||
pool.pop_front();
|
||||
delete rez;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace hnswlib
|
Loading…
Reference in New Issue