cat articles/svm-embedding-retriever
Similar embedding search with SVM: an alternative to kNN
LangChain v0.0.141 added an implementation called SVM Retriever. It finds the top-K embeddings similar to a single query embedding from a set of embeddings by using SVM. I wondered how that worked, looked into it, and found it interesting because it used an idea I did not know. This is a note about that.
kNN vs SVM
There is a notebook called knn_vs_svm.ipynb, which this implementation is based on. Machine-translating the beginning gives the following:
A common workflow is to index some data based on embeddings, then, given a new query embedding, use k-Nearest Neighbor search to retrieve the most similar examples. For example, you could imagine embedding a large collection of papers based on their abstracts, then giving a new paper of interest and retrieving the most similar papers.
In my experience, if you have a little extra compute budget, using SVM instead of kNN always works better. Here is an example:
k-nearest neighbors computes using Euclidean distance, but this approach uses SVM. The way SVM is used is interesting. Quoting from the notebook:
# Wired: use an SVM
from sklearn import svm
# create the "Dataset"
x = np.concatenate([query[None,...], embeddings]) # x is (1001, 1536) array, with query now as the first row
y = np.zeros(1001)
y[0] = 1 # we have a single positive example, mark it as such
# train our (Exemplar) SVM
# docs: https://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(x, y) # train
# infer on whatever data you wish, e.g. the original data
similarities = clf.decision_function(x)
sorted_ix = np.argsort(-similarities)
print("top 10 results:")
for k in sorted_ix[:10]:
print(f"row {k}, similarity {similarities[k]}")
It labels only the target single embedding as 1 and all others as 0, then trains LinearSVC as a classification task. The confidence score is obtained as a value around 1 to -1, and the top-K items closest to 1 are treated as the most similar items.
Instead of simple Euclidean distance, the score is computed while considering a space learned by SVM, so it does seem likely to produce better results. Using LinearSVC to compute that made me think, "I see." SVM Retriever is an abstraction that makes this method convenient to use in LangChain.
Comparing kNN and SVM results
Let's use 450 Japanese items from AI News and compare search results from kNN and SVM for a specific query.
query: 生成AIと著作権
=== kNN ===
0.886: 生成AIの猛烈な進化と著作権制度~技術発展と著作権者の利益のバランスをとるには~ | STORIA法律事務所
0.880: スター・ウォーズやハリポタの人気キャラと話せるAIの「著作権問題」をどう考えるべきか | シリコンバレーの「生き字引」がズバリ指摘 | クーリエ・ジャポン
0.876: 生成AIの利用ガイドライン作成のための手引き | STORIA法律事務所
0.876: ダブスタクソイナゴは生成AIの法的議論に参加してくるんじゃねえ!!
0.874: 画像生成AI “クリエーターの権利脅かされる” 法整備など提言 | NHK | AI(人工知能)
0.870: 【AI】生成AIを利用する場合に気を付けなければならない著作権の知識|福岡真之介|note
0.868: AIイラストに規制を求める団体の理事「木目百二」氏が二次創作のガイドライン違反で支援サイトの作品全消し、謝罪に追い込まれる - Togetter
0.865: 生成AI「開発規制、望ましくない」 松本総務相 - 日本経済新聞
=== SVM ===
-0.305: 生成AIで作品、それって著作権侵害? 福井健策弁護士に聞く:朝日新聞デジタル
-0.384: 生成AIの猛烈な進化と著作権制度~技術発展と著作権者の利益のバランスをとるには~ | STORIA法律事務所
-0.402: ダブスタクソイナゴは生成AIの法的議論に参加してくるんじゃねえ!!
-0.408: AIイラストに規制を求める団体の理事「木目百二」氏が二次創作のガイドライン違反で支援サイトの作品全消し、謝罪に追い込まれる - Togetter
-0.436: 画像生成AIによる作品の無許可使用を主張した写真家が逆に損害賠償を請求される - GIGAZINE
-0.479: アーティストのGrimes、生成AIで自分の声を自由に使っていいとツイート - ITmedia NEWS
-0.482: 生成AIの利用ガイドライン作成のための手引き | STORIA法律事務所
-0.483: スター・ウォーズやハリポタの人気キャラと話せるAIの「著作権問題」をどう考えるべきか | シリコンバレーの「生き字引」がズバリ指摘 | クーリエ・ジャポン
For the result above, both look reasonable at a glance. Let's try a slightly more difficult query.
query: 大規模言語モデルを低スペックのマシンで動かしたい
=== kNN ===
0.872: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
0.861: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
0.855: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
0.853: ChatGPT対抗のオープンソース言語モデル「StableLM」。日本語版も? - PC Watch
0.851: “画像の面白さ”を解説できるAI「MiniGPT-4」 写真からラップや詩、料理レシピ作成 デモサイトも公開中:Innovative Tech(1/2 ページ) - ITmedia NEWS
0.850: チャットAI「StableLM」発表 オープンソースモデルで商用可 「Stable Diffusion」開発元から - ITmedia NEWS
0.849: Googleの大規模言語モデル「Bard」、日本でも利用可能に。英語のみだが、改良されたPaLMベース | テクノエッジ TechnoEdge
0.849: Stability AIがオープンソースで商用利用も可能な大規模言語モデル「StableLM」をリリース - GIGAZINE
=== SVM ===
-0.359: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-0.366: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-0.451: 深層学習コンパイラスタックと最適化
-0.456: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-0.471: dolly-v2-12bという120億パラメータの言語モデルを使ってみた!|Masayuki Abe|note
-0.490: Googleの大規模言語モデル「Bard」、日本でも利用可能に。英語のみだが、改良されたPaLMベース | テクノエッジ TechnoEdge
-0.504: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-0.510: Webブラウザ上で3D/2Dモデルをぬるぬる動かせる「Babylon.js 6」正式版に。レンダリング性能が最大50倍、WASM化した物理演算エンジン搭載、液体のレンダリングも - Publickey
Depending on the query, the results can differ quite a bit. I also implemented a hybrid search that ensembles kNN and SVM results, so let's look at that.
query: 大規模言語モデルを低スペックのマシンで動かしたい
=== kNN ===
-3.816: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-3.527: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
-2.920: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.591: ChatGPT対抗のオープンソース言語モデル「StableLM」。日本語版も? - PC Watch
-2.436: “画像の面白さ”を解説できるAI「MiniGPT-4」 写真からラップや詩、料理レシピ作成 デモサイトも公開中:Innovative Tech(1/2 ページ) - ITmedia NEWS
=== SVM ===
-3.923: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-3.865: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-3.140: 深層学習コンパイラスタックと最適化
-3.097: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.962: dolly-v2-12bという120億パラメータの言語モデルを使ってみた!|Masayuki Abe|note
=== Hybrid ===
-3.869: 大規模言語モデルを自社でトレーニング&活用する方法|mah_lab / 西見 公宏|note
-3.102: RWKV14Bを日本語AlpacaデータセットでLoRAして公開しました(ご家庭で動く!?)|shi3z|note
-2.913: 大規模言語モデル間の性能比較まとめ|mah_lab / 西見 公宏|note
-2.844: LLMをアプリ開発に統合するSDK「Semantic Kernel」がPythonに対応、TypeScriptへの対応も検討中|CodeZine(コードジン)
-2.558: チャットAIをブラウザのWebGPUだけで実行でき日本語も使用できる「Web LLM」、実際に試してみる方法はこんな感じ - GIGAZINE
This looks better to me. Since it is easy to try, using SVM in addition to kNN search or similarity search seems like a reasonable option. Of course kNN is overwhelmingly faster, but if SVM can be used at practical speed, it seems useful.
Extra code
embs must be created separately as an array of embeddings. texts is paired data for embs. LangChain's SVM Retriever makes this easier, but it does not expose scores, so I implemented it myself.
# base: https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
from sklearn import svm
import numpy as np
from langchain.embeddings import OpenAIEmbeddings
def knn_top_k(query_emb, embs, k=10):
l2_embs = embs / np.sqrt((embs**2).sum(1, keepdims=True))
l2_query = query_emb / np.sqrt((query_emb**2).sum())
similarities = l2_embs.dot(l2_query)
sorted_index = np.argsort(-similarities)
res_index = sorted_index[1:k+1]
return res_index, similarities[res_index], -similarities
def svm_top_k(query_emb, embs, k=10):
X = np.concatenate([query_emb[None, ...], embs])
y = np.zeros(X.shape[0])
y[0] = 1
clf = svm.LinearSVC(class_weight='balanced', verbose=False, max_iter=10000, tol=1e-6, C=0.1)
clf.fit(X, y)
similarities = clf.decision_function(X)
sorted_index = np.argsort(-similarities)
res_index = sorted_index[1:k+1] - 1
return res_index, similarities[res_index + 1], -similarities[1:]
def get_query_emb(text):
emb = OpenAIEmbeddings().embed_query(text) # type: ignore
return np.array(emb)
def join_colon(num_list_a, list_b):
return [f'{a:.3f}: {b}' for a, b in zip(num_list_a, list_b)]
def knn_svm(text, embs, texts, k=5):
query_emb = get_query_emb(text)
knn_index, knn_similarities, _ = knn_top_k(query_emb, embs, k)
svm_index, svm_similarities, _ = svm_top_k(query_emb, embs, k)
print('query: ', text)
print('=== kNN ===')
print("\n".join(join_colon(knn_similarities, texts[knn_index])))
print('=== SVM ===')
print("\n".join(join_colon(svm_similarities, texts[svm_index])))
def hyblid_knn_svm(text_or_emb, embs, texts, k=5):
if isinstance(text_or_emb, str):
query_emb = get_query_emb(text_or_emb)
print('query: ', text_or_emb) # type: ignore
else:
query_emb = text_or_emb
# 全件取得する
knn_index, knn_similarities, knn_all_scores = knn_top_k(query_emb, embs, embs.shape[0])
svm_index, svm_similarities, svm_all_scores = svm_top_k(query_emb, embs, embs.shape[0])
# score を正規化する
knn_score_normalized = (knn_all_scores - np.mean(knn_all_scores)) / np.std(knn_all_scores)
svm_score_normalized = (svm_all_scores - np.mean(svm_all_scores)) / np.std(svm_all_scores)
# それぞれのスコアを足し合わせて、ハイブリッドなスコアを作る
hybrid_similarities = (knn_score_normalized + svm_score_normalized) / 2
hybrid_index = np.argsort(hybrid_similarities)[:k]
print('=== kNN ===')
print("\n".join(join_colon(np.sort(knn_score_normalized)[:k], texts[knn_index][:k])))
print('=== SVM ===')
print("\n".join(join_colon(np.sort(svm_score_normalized)[:k], texts[svm_index][:k])))
print('=== Hybrid ===')
print("\n".join(join_colon(hybrid_similarities[hybrid_index][:k], texts[hybrid_index][:k])))