Welcome EmbeddingGemma, Google’s new efficient embedding model

Dataemia


Today, Google releases EmbeddingGemma, a state-of-the-art multilingual embedding model perfect for on-device use cases. Designed for speed and efficiency, the model features a compact size of 308M parameters and a 2K context window, unlocking new possibilities for mobile RAG pipelines, agents, and more. EmbeddingGemma is trained to support over 100 languages and is the highest-ranking text-only multilingual embedding model under 500M parameters on the Massive Text Embedding Benchmark (MTEB) at the time of writing.






Introduction

Text embeddings have become the backbone of modern natural‑language applications, turning words, sentences, and documents into dense vectors that capture meaning, sentiment, and intent. These vectors enable fast similarity search, clustering, classification, and retrieval across massive corpora, powering everything from recommendation engines and semantic search to retrieval-augmented generation and code‑search tools. Embedding models that calculate these embeddings are widely used, with well over 200 million monthly downloads on Hugging Face.

Building on this foundation, Google DeepMind’s EmbeddingGemma arrives as the newest, most capable small multilingual embedding model yet. With just 308M parameters, a 2k‑token context window, and support for over 100 languages, EmbeddingGemma delivers state‑of‑the‑art performance on the Massive Multilingual Text Embedding Benchmark (MMTEB) while staying under 200 MB of RAM when quantized.
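As a rough sanity check on the memory figure, here is the weight-only arithmetic at a few precisions. This is a sketch under simplifying assumptions: it uses exactly 308M parameters and ignores activations, runtime buffers, and quantization overhead such as scale factors. The post does not say which quantization scheme the sub-200 MB figure refers to.

```python
def weight_memory_mb(num_params: int, bits_per_param: int) -> float:
    """Memory for the weights alone, in megabytes (1 MB = 10**6 bytes)."""
    return num_params * bits_per_param / 8 / 1e6


NUM_PARAMS = 308_000_000  # "308M parameters", taken as an exact count for this estimate

for bits in (32, 16, 8, 4):
    print(f"{bits:>2}-bit weights: {weight_memory_mb(NUM_PARAMS, bits):,.0f} MB")
# 32-bit weights: 1,232 MB
# 16-bit weights: 616 MB
#  8-bit weights: 308 MB
#  4-bit weights: 154 MB
```

At 8 bits per weight the weights alone already exceed 200 MB, so the figure is consistent with quantizing most weights below 8 bits, for example to 4 bits.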

These design choices make EmbeddingGemma a practical, open-source tool for computing high-quality multilingual embeddings on everyday devices.

In this blog post, we describe the EmbeddingGemma architecture and training, and show you how to use the model with various frameworks like Sentence Transformers, LangChain, LlamaIndex, Haystack, txtai, Transformers.js, Text Embeddings Inference, and ONNX Runtime.

Afterwards, we demonstrate how to finetune EmbeddingGemma on your domain for even stronger performance. In our example, we finetune EmbeddingGemma on the Medical Instruction and Retrieval Dataset (MIRIAD). The resulting model, sentence-transformers/embeddinggemma-300m-medical, achieves state-of-the-art performance on our task: retrieving passages of scientific medical papers in response to detailed medical questions. It even outperforms models twice as big on this task.



Architecture

EmbeddingGemma builds on the Gemma3 transformer backbone, modified to use bi-directional attention instead of causal (one-way) attention. This means that earlier tokens in the sequence can attend to later tokens, effectively turning the architecture from a decoder into an encoder. Encoder models can outperform LLMs, which are decoders, on embedding tasks like retrieval (Weller et al., 2025). With this backbone, the model can process up to 2048 tokens at once, which is sufficient for typical retrieval inputs, especially given that longer inputs often lead to information loss in the resulting text embeddings.
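To make the causal vs. bi-directional distinction concrete, here is a minimal sketch of the two attention masks as boolean matrices, where `mask[i][j]` says whether token i may attend to token j. The function names are illustrative rather than part of any library, and real implementations additionally mask out padding tokens.

```python
def causal_mask(seq_len: int) -> list[list[bool]]:
    """Decoder-style mask: token i may attend only to tokens 0..i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


def bidirectional_mask(seq_len: int) -> list[list[bool]]:
    """Encoder-style mask: every token may attend to every token."""
    return [[True] * seq_len for _ in range(seq_len)]


# With a causal mask, the first token cannot see anything after it,
# so its representation carries no information about later tokens.
print(causal_mask(3))         # [[True, False, False], [True, True, False], [True, True, True]]
print(bidirectional_mask(3))  # [[True, True, True], [True, True, True], [True, True, True]]
```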

Beyond the new Gemma3-based encoder backbone, which produces token embeddings, a mean pooling layer converts these token embeddings into text embeddings. Lastly, two dense layers transform the text embeddings into their final form, a 768-dimensional vector.
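The pooling and projection steps can be sketched in plain Python. This is a simplified illustration under stated assumptions: the toy dimensions and weights are hypothetical, the bias terms and any activation functions of the real dense layers are omitted, and the helper names `mean_pool` and `dense` are ours rather than a library API.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the token embeddings, ignoring padding positions (mask == 0)."""
    dim = len(token_embeddings[0])
    totals = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for k in range(dim):
                totals[k] += vector[k]
            count += 1
    return [t / count for t in totals]


def dense(vector, weights):
    """Linear projection without bias: `weights` has one row per output dimension."""
    return [sum(w * x for w, x in zip(row, vector)) for row in weights]


# Toy example: 3 tokens of dimension 2, where the last token is padding.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
text_embedding = mean_pool(tokens, mask)                    # [2.0, 3.0]

# Two stacked dense layers with hypothetical toy weights (2 -> 3 -> 2).
hidden = dense(text_embedding, [[1, 0], [0, 1], [1, 1]])    # [2.0, 3.0, 5.0]
output = dense(hidden, [[1, 0, 0], [0, 0, 1]])              # [2.0, 5.0]
print(text_embedding, output)
```

In the real model, the final layer produces the 768-dimensional text embedding.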

The EmbeddingGemma model has been trained with Matryoshka Representation Learning (MRL), allowing you to truncate the 768‑dimensional output to 512, 256, or 128 dimensions on demand. This results in faster downstream processing and lower memory and disk space utilization. See the Sentence Transformers usage for a snippet showing how to perform this truncation.
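Independent of any framework, MRL truncation simply keeps the leading values of the vector. The sketch below (the helper name `truncate_embedding` is hypothetical) also re-normalizes the result to unit length. This matters if you score with a plain dot product; cosine similarity is unaffected by the rescaling.

```python
import math

MRL_DIMS = (768, 512, 256, 128)  # dimensionalities named in the text above


def truncate_embedding(embedding, dim, renormalize=True):
    """Keep the first `dim` values of a Matryoshka embedding.

    Re-normalizing to unit length keeps dot-product scores equal to
    cosine similarity after truncation.
    """
    if dim not in MRL_DIMS:
        raise ValueError(f"dim must be one of {MRL_DIMS}, got {dim}")
    head = list(embedding[:dim])
    if renormalize:
        norm = math.sqrt(sum(x * x for x in head))
        if norm > 0:
            head = [x / norm for x in head]
    return head


# A made-up 768-d vector whose first 128 values have norm 5.
full = [3.0, 4.0] + [0.0] * 126 + [1.0] * 640
small = truncate_embedding(full, 128)
print(len(small), small[:2])  # 128 [0.6, 0.8]
```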

The model has been trained using a carefully curated, multilingual corpus totalling approximately 320 billion tokens. The proprietary dataset is a blend of publicly available web text, code and technical documentation, and synthetic task‑specific examples. It has been filtered to avoid Child Sexual Abuse Material (CSAM), sensitive data, and low-quality or unsafe content.



Evaluation

EmbeddingGemma was benchmarked on the MMTEB (Multilingual, v2) and MTEB (English, v2) suites, which span a wide range of tasks, domains, and languages. Despite its modest 308M‑parameter size, the model consistently beats comparable baselines while keeping a very small memory footprint.

Figures: MTEB (Multilingual, v2) performance; MTEB (English, v2) performance.

The results will be listed on the official MTEB Leaderboard. We exclude any model that has been trained on more than 20% of the MTEB data, to mitigate potential over‑fitting.






Usage

EmbeddingGemma is integrated with many popular tools, making it easy to incorporate into your existing workflows and applications. The model has been integrated in Sentence Transformers, and thus also in projects that use Sentence Transformers behind the scenes, such as LangChain, LlamaIndex, Haystack, and txtai. See the examples below to get started with your preferred framework.

For production deployments, you can use Text Embeddings Inference (TEI) to serve the model efficiently on various hardware configurations, and Transformers.js to run it in web applications.

Regardless of your framework choice, you should be mindful of the prompts. For embedding models, prompts are prepended to the input text to allow the model to distinguish between different tasks. EmbeddingGemma was trained with these prompt names and prompts, so they should also be included when using the model:

  • query: "task: search result | query: ",
  • document: "title: none | text: ",
  • BitextMining: "task: search result | query: ",
  • Clustering: "task: clustering | query: ",
  • Classification: "task: classification | query: ",
  • InstructionRetrieval: "task: code retrieval | query: ",
  • MultilabelClassification: "task: classification | query: ",
  • PairClassification: "task: sentence similarity | query: ",
  • Reranking: "task: search result | query: ",
  • Retrieval-query: "task: search result | query: ",
  • Retrieval-document: "title: none | text: ",
  • STS: "task: sentence similarity | query: ",
  • Summarization: "task: summarization | query: "

In Sentence Transformers, the query and document prompts are used automatically when calling model.encode_query and model.encode_document, but for other frameworks you might have to:

  1. specify prompt names (e.g. “Reranking”),
  2. specify prompt strings (e.g. “task: search result | query: ”), or
  3. manually prepend the prompts to your input text.
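Option 3 is plain string concatenation. A minimal sketch, assuming you manage prompts yourself: `PROMPTS` holds a subset of the prompt strings listed above, and `add_prompt` is a hypothetical helper name.

```python
# A subset of the prompt strings listed above, keyed by prompt name.
PROMPTS = {
    "query": "task: search result | query: ",
    "document": "title: none | text: ",
    "Clustering": "task: clustering | query: ",
    "Classification": "task: classification | query: ",
    "STS": "task: sentence similarity | query: ",
}


def add_prompt(texts, prompt_name):
    """Prepend the named prompt to every input text."""
    try:
        prefix = PROMPTS[prompt_name]
    except KeyError:
        raise ValueError(f"Unknown prompt name: {prompt_name!r}") from None
    return [prefix + text for text in texts]


print(add_prompt(["Which planet is known as the Red Planet?"], "query"))
# ['task: search result | query: Which planet is known as the Red Planet?']
```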

The following example scripts will demonstrate this with various frameworks.



Sentence Transformers

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install "sentence-transformers>=5.0.0"



Retrieval

Inference using Sentence Transformers is straightforward; see this example for semantic search:

from sentence_transformers import SentenceTransformer

# Download the model from the Hugging Face Hub
model = SentenceTransformer("google/embeddinggemma-300m")

# Run inference with a query and a set of documents
query = "Which planet is known as the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
# (768,) (4, 768)

# Compute the similarities between the query and each document
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)

# Rank the documents from most to least similar
ranking = similarities.argsort(descending=True)[0]
print(ranking)
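To see what the similarity and ranking steps compute, here is a dependency-free sketch with toy 3-dimensional vectors (hypothetical values standing in for real 768-dimensional embeddings). It assumes cosine similarity as the scoring function; the helper names are ours.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_documents(query_embedding, document_embeddings):
    """Return document indices sorted from most to least similar to the query."""
    scores = [cosine_similarity(query_embedding, d) for d in document_embeddings]
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)


# Toy vectors (hypothetical values).
query = [0.1, 0.9, 0.2]
docs = [[0.9, 0.1, 0.0], [0.2, 0.8, 0.1], [0.0, 0.5, 0.5]]
print(rank_documents(query, docs))  # [1, 2, 0]
```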

Non-retrieval tasks

If you’re not using this model for information retrieval, you’re likely best off using the general encode method together with the model prompt that best describes your downstream task, chosen from these options:

  • BitextMining: Find translated sentence pairs in two languages.
  • Clustering: Find similar texts to group them together.
  • Classification: Assign predefined labels to texts.
  • InstructionRetrieval: Retrieve relevant code snippets based on natural language instructions.
  • MultilabelClassification: Assign multiple labels to texts.
  • PairClassification: Classify pairs of texts, e.g. as paraphrases or duplicates.
  • Reranking: Reorder search results based on relevance.
  • Retrieval-query: Encode queries for document retrieval.
  • Retrieval-document: Encode documents to be retrieved.
  • STS: Compute semantic textual similarity between texts.
  • Summarization: Score generated summaries by their similarity to reference summaries.

from sentence_transformers import SentenceTransformer

# Download the model from the Hugging Face Hub
model = SentenceTransformer("google/embeddinggemma-300m")

# Inspect the available prompt names and their prompt strings
print(model.prompts)

# Encode texts with the prompt that matches the downstream task (here: STS)
texts = [
    "The weather is beautiful today.",
    "It's a lovely day outside.",
    "The stock market crashed yesterday.",
    "I enjoy programming with Python."
]
embeddings = model.encode(texts, prompt_name="STS")
print(embeddings.shape)
# (4, 768)

# Compute pairwise similarities between all texts
similarities = model.similarity(embeddings, embeddings)
print(similarities)
"""
tensor([[1.0000, 0.9305, 0.4660, 0.4326],
        [0.9305, 1.0000, 0.4227, 0.4434],
        [0.4660, 0.4227, 1.0000, 0.2638],
        [0.4326, 0.4434, 0.2638, 1.0000]])
"""

Truncating embedding dimensionality for faster and cheaper search

Because google/embeddinggemma-300m was trained with MRL, the embeddings generated by this model can be truncated to lower dimensionalities without considerably hurting the evaluation performance. Embeddings with lower dimensionalities are both cheaper to store on disk and in memory, as well as faster for downstream tasks like retrieval, clustering, or classification.

In Sentence Transformers, you can set a lower dimensionality using the truncate_dim parameter on either the SentenceTransformer initialization or when calling model.encode/model.encode_query/model.encode_document:

from sentence_transformers import SentenceTransformer

# Download the model and truncate all embeddings to 256 dimensions
model = SentenceTransformer("google/embeddinggemma-300m", truncate_dim=256)

# Run inference with a query and a set of documents
query = "Which planet is known as the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
query_embeddings = model.encode_query(query)
document_embeddings = model.encode_document(documents)
print(query_embeddings.shape, document_embeddings.shape)
# (256,) (4, 256)

# Compute the similarities between the query and each document
similarities = model.similarity(query_embeddings, document_embeddings)
print(similarities)

# Rank the documents from most to least similar
ranking = similarities.argsort(descending=True)[0]
print(ranking)

Note that the ranking is preserved despite using 3x smaller embeddings (256 instead of 768 dimensions).
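The "3x smaller" figure follows from 768 / 256 = 3. Here is a quick sketch of what that means for raw index size, assuming float32 vectors (4 bytes per value), a hypothetical corpus of one million documents, and no index overhead:

```python
def index_size_bytes(num_docs: int, dim: int, bytes_per_value: int = 4) -> int:
    """Raw storage for num_docs vectors of size dim (float32 = 4 bytes per value)."""
    return num_docs * dim * bytes_per_value


NUM_DOCS = 1_000_000  # hypothetical corpus size

for dim in (768, 512, 256, 128):
    size_gb = index_size_bytes(NUM_DOCS, dim) / 1e9
    print(f"{dim:>3} dims: {size_gb:.3f} GB")
# 768 dims: 3.072 GB
# 512 dims: 2.048 GB
# 256 dims: 1.024 GB
# 128 dims: 0.512 GB
```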



LangChain

If you prefer, you can also use the LangChain HuggingFaceEmbeddings, which uses Sentence Transformers behind the scenes. Note that you’ll have to tell LangChain to use the prompts called “query” and “document” for queries and documents, respectively. This example involves a simple information retrieval setup, but the same embedding model can be used in more complex scenarios.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install langchain
pip install langchain-community
pip install langchain-huggingface
pip install faiss-cpu

from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface.embeddings import HuggingFaceEmbeddings

# Load EmbeddingGemma through Sentence Transformers, using the "query" prompt
# for queries and the "document" prompt for documents
embedder = HuggingFaceEmbeddings(
    model_name="google/embeddinggemma-300m",
    query_encode_kwargs={"prompt_name": "query"},
    encode_kwargs={"prompt_name": "document"}
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Wrap each text in a LangChain Document
documents = [Document(page_content=text, metadata={"id": i}) for i, text in enumerate(data)]

# Embed the documents and index them in FAISS; for normalized embeddings,
# the inner product equals the cosine similarity
vector_store = FAISS.from_documents(documents, embedder, distance_strategy="MAX_INNER_PRODUCT")

# Search for the 3 documents most similar to the query
query = "Which planet is known as the Red Planet?"
results = vector_store.similarity_search_with_score(query, k=3)

# Print the results
for doc, score in results:
    print(f"Text: {doc.page_content} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



LlamaIndex

EmbeddingGemma is also supported in LlamaIndex, as it uses Sentence Transformers under the hood. For correct behaviour, you need to specify the query and document prompts as defined in the model configuration; otherwise, performance will be suboptimal. This script shows a rudimentary example of using EmbeddingGemma with LlamaIndex, but you can also use the HuggingFaceEmbedding class in more complex settings.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install llama-index
pip install llama-index-embeddings-huggingface
pip install llama-index-vector-stores-faiss

import faiss
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores import VectorStoreQuery
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore

# Load EmbeddingGemma with the query and document prompts as instructions
embeddings = HuggingFaceEmbedding(
    model_name="google/embeddinggemma-300m",
    query_instruction="task: search result | query: ",
    text_instruction="title: none | text: ",
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Embed the documents and store them in an exact inner-product FAISS index
# (768 is EmbeddingGemma's output dimensionality)
store = FaissVectorStore(faiss_index=faiss.IndexFlatIP(768))
store.add([TextNode(id_=str(i), text=text, embedding=embeddings.get_text_embedding(text)) for i, text in enumerate(data)])

# Embed the query and retrieve the 3 most similar documents
query = "Which planet is known as the Red Planet?"
query_embedding = embeddings.get_query_embedding(query)
results = store.query(VectorStoreQuery(query_embedding=query_embedding, similarity_top_k=3))

# Print the results; the returned ids are FAISS positions, which map back onto `data`
for idx, score in zip(results.ids, results.similarities):
    print(f"Text: {data[int(idx)]} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



Haystack

EmbeddingGemma can also be used with Haystack, a framework for building production-ready search and language applications. Like LangChain and LlamaIndex, Haystack uses Sentence Transformers behind the scenes and requires you to specify the appropriate prompts. The following example shows how to set up a basic retrieval pipeline using EmbeddingGemma with Haystack.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install haystack-ai

from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersDocumentEmbedder, SentenceTransformersTextEmbedder
from haystack.components.retrievers import InMemoryEmbeddingRetriever
from haystack.document_stores.in_memory import InMemoryDocumentStore

# Create an in-memory document store
document_store = InMemoryDocumentStore()

# Create separate embedders for documents and queries, each with its own prompt
document_embedder = SentenceTransformersDocumentEmbedder(
    model="google/embeddinggemma-300m", encode_kwargs={"prompt_name": "document"}
)
query_embedder = SentenceTransformersTextEmbedder(
    model="google/embeddinggemma-300m", encode_kwargs={"prompt_name": "query"}
)
document_embedder.warm_up()
query_embedder.warm_up()

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
]

# Embed the documents and write them to the document store
documents = [Document(content=text, id=str(i)) for i, text in enumerate(data)]
documents_with_embeddings = document_embedder.run(documents=documents)["documents"]
document_store.write_documents(documents_with_embeddings)

# Build a query pipeline: embed the query, then retrieve the 3 most similar documents
query_pipeline = Pipeline()
query_pipeline.add_component("text_embedder", query_embedder)
query_pipeline.add_component("retriever", InMemoryEmbeddingRetriever(document_store=document_store, top_k=3))
query_pipeline.connect("text_embedder.embedding", "retriever.query_embedding")

# Run the pipeline with a query
query = "Which planet is known as the Red Planet?"
results = query_pipeline.run({"text_embedder": {"text": query}})

# Print the results
for document in results["retriever"]["documents"]:
    print(f"Text: {document.content} (score: {document.score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



txtai

txtai is also compatible with EmbeddingGemma. Like the other frameworks, txtai uses Sentence Transformers under the hood and needs the appropriate prompts for optimal performance with EmbeddingGemma. The following example demonstrates how to set up a basic retrieval system with txtai.

You will need to install the following packages:

pip install git+https://github.com/huggingface/transformers@v4.56.0-Embedding-Gemma-preview
pip install sentence-transformers
pip install txtai

from txtai import Embeddings

# Load EmbeddingGemma via Sentence Transformers, with the query and document prompts
embeddings = Embeddings(
    path="google/embeddinggemma-300m",
    method="sentence-transformers",
    instructions={
        "query": "task: search result | query: ",
        "data": "title: none | text: ",
    }
)

data = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]

# Embed and index the documents
embeddings.index(data)

# Search for the 3 documents most similar to the query
query = "Which planet is known as the Red Planet?"
results = embeddings.search(query, 3)

# Print the results as (id, score) pairs mapped back onto `data`
for idx, score in results:
    print(f"Text: {data[int(idx)]} (score: {score:.4f})")
"""
Text: Mars, known for its reddish appearance, is often referred to as the Red Planet. (score: 0.6359)
Text: Jupiter, the largest planet in our solar system, has a prominent red spot. (score: 0.4930)
Text: Saturn, famous for its rings, is sometimes mistaken for the Red Planet. (score: 0.4889)
"""



Transformers.js

You can even run EmbeddingGemma 100% locally in your browser with Transformers.js! If you haven’t already, you can install the library from NPM using:

npm i @huggingface/transformers

You can then compute embeddings as follows:

import { AutoModel, AutoTokenizer, matmul } from "@huggingface/transformers";

// Download the ONNX version of the model from the Hugging Face Hub
const model_id = "onnx-community/embeddinggemma-300m-ONNX";
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModel.from_pretrained(model_id, {
  dtype: "fp32",
});

// Prepend the task-specific prompts to the query and the documents
const prefixes = {
  query: "task: search result | query: ",
  document: "title: none | text: ",
};
const query = prefixes.query + "Which planet is known as the Red Planet?";
const documents = [
  "Venus is often called Earth's twin because of its similar size and proximity.",
  "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
  "Jupiter, the largest planet in our solar system, has a prominent red spot.",
  "Saturn, famous for its rings, is sometimes mistaken for the Red Planet.",
].map((x) => prefixes.document + x);

// Tokenize all texts together and compute their sentence embeddings
const inputs = await tokenizer([query, ...documents], { padding: true });
const { sentence_embedding } = await model(inputs);

// Compute all pairwise scores; row 0 holds the query's scores, minus its self-similarity
const scores = await matmul(sentence_embedding, sentence_embedding.transpose(1, 0));
const similarities = scores.tolist()[0].slice(1);
console.log(similarities);

// Rank the documents from most to least similar
const ranking = similarities.map((score, index) => ({ index, score })).sort((a, b) => b.score - a.score);
console.log(ranking);

Text Embeddings Inference

You can easily deploy EmbeddingGemma for both development and production using Text Embeddings Inference (TEI) version 1.8.1 or later.

# CPU
docker run -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cpu-1.8.1 --model-id google/embeddinggemma-300m --dtype float32
# CPU, using the ONNX export of the model
docker run -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cpu-1.8.1 --model-id onnx-community/embeddinggemma-300m-ONNX --dtype float32 --pooling mean
# NVIDIA GPU (CUDA)
docker run --gpus all --shm-size 1g -p 8080:80 ghcr.io/huggingface/text-embeddings-inference:cuda-1.8.1 --model-id google/embeddinggemma-300m --dtype float32

If you run the Docker container with the cuda-1.8.1 tag, it includes support for multiple GPU architectures: Turing, Ampere, Ada Lovelace, and Hopper. For a lighter image tailored to just your GPU, you can instead use an architecture-specific tag: turing-1.8.1 (Turing), 1.8.1 (Ampere 8.0, e.g. A100) or 86-1.8.1 (Ampere 8.6, e.g. A10), 89-1.8.1 (Ada Lovelace), or hopper-1.8.1 (Hopper).
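If you script your deployments, the tag choice can be expressed as a lookup on the GPU's CUDA compute capability. This is a sketch under assumptions: the tags are the ones listed above, and we assume the plain 1.8.1 tag targets Ampere 8.0 (A100-class) GPUs while 86-1.8.1 targets Ampere 8.6. The helper name `tei_gpu_image` is hypothetical; unknown capabilities fall back to the multi-architecture cuda tag.

```python
TEI_VERSION = "1.8.1"

# Architecture-specific image tags, keyed by CUDA compute capability (major, minor).
TAG_BY_COMPUTE_CAPABILITY = {
    (7, 5): f"turing-{TEI_VERSION}",  # Turing, e.g. T4
    (8, 0): TEI_VERSION,              # Ampere 8.0, e.g. A100
    (8, 6): f"86-{TEI_VERSION}",      # Ampere 8.6, e.g. A10
    (8, 9): f"89-{TEI_VERSION}",      # Ada Lovelace, e.g. L4
    (9, 0): f"hopper-{TEI_VERSION}",  # Hopper, e.g. H100
}


def tei_gpu_image(major, minor):
    """Return the lighter per-architecture image, or the multi-arch cuda image."""
    tag = TAG_BY_COMPUTE_CAPABILITY.get((major, minor), f"cuda-{TEI_VERSION}")
    return f"ghcr.io/huggingface/text-embeddings-inference:{tag}"


print(tei_gpu_image(8, 6))  # ghcr.io/huggingface/text-embeddings-inference:86-1.8.1
```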

Once deployed, regardless of the device or runtime, you can leverage the /v1/embeddings endpoint based on the OpenAI Embeddings API Specification to generate embeddings.

curl http://0.0.0.0:8080/v1/embeddings -H "Content-Type: application/json" -d '{"model":"google/embeddinggemma-300m","input":["task: search result | query: Which planet is known as the Red Planet?","task: search result | query: Where did Amelia Earhart first fly?"]}'
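The same request can be built from Python with only the standard library. As in the curl example above, the prompt is prepended to each input manually. This is a minimal sketch assuming TEI listens on localhost:8080; the helper name is hypothetical, and the request is only built, not sent.

```python
import json
import urllib.request

QUERY_PROMPT = "task: search result | query: "  # from the prompt list earlier in this post


def build_v1_embeddings_request(texts, base_url="http://localhost:8080",
                                model="google/embeddinggemma-300m", prompt=QUERY_PROMPT):
    """Build (but do not send) a POST request for TEI's OpenAI-compatible endpoint."""
    body = {"model": model, "input": [prompt + text for text in texts]}
    return urllib.request.Request(
        f"{base_url}/v1/embeddings",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


request = build_v1_embeddings_request(["Which planet is known as the Red Planet?"])

# With a TEI container running, send it and read the vectors (OpenAI response format):
# with urllib.request.urlopen(request) as response:
#     vectors = [item["embedding"] for item in json.load(response)["data"]]
```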

Alternatively, you can use the /embed endpoint from the Text Embeddings Inference Embeddings API, which supports the prompt_name parameter: instead of manually prepending the prompt to each input, you select it by name via prompt_name.

curl http://0.0.0.0:8080/embed -H "Content-Type: application/json" -d '{"inputs":["Which planet is known as the Red Planet?","Where did Amelia Earhart first fly?"],"prompt_name":"query","normalize":true}'

Additionally, since google/embeddinggemma-300m was trained with Matryoshka Representation Learning (MRL), you can use the dimensions parameter on both /v1/embeddings and /embed to truncate the embeddings to lower dimensionalities (512, 256, or 128) without considerably hurting evaluation performance.
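Combining prompt_name and dimensions, a JSON body for /embed can be assembled as follows. This is a sketch: the helper name is hypothetical, and restricting dimensions to the sizes named above is a choice made here rather than a documented server-side limit.

```python
import json

ALLOWED_DIMENSIONS = (512, 256, 128)  # truncation sizes mentioned in the text


def build_embed_payload(texts, prompt_name="query", dimensions=None, normalize=True):
    """Build a JSON body for TEI's /embed route.

    prompt_name lets the server prepend the prompt; dimensions (optional)
    requests MRL-truncated embeddings.
    """
    payload = {"inputs": list(texts), "prompt_name": prompt_name, "normalize": normalize}
    if dimensions is not None:
        if dimensions not in ALLOWED_DIMENSIONS:
            raise ValueError(f"dimensions must be one of {ALLOWED_DIMENSIONS}")
        payload["dimensions"] = dimensions
    return json.dumps(payload)


print(build_embed_payload(["Which planet is known as the Red Planet?"], dimensions=256))
# {"inputs": ["Which planet is known as the Red Planet?"], "prompt_name": "query", "normalize": true, "dimensions": 256}
```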



ONNX Runtime

You can also run the model directly with ONNX Runtime, making it highly portable and cross-platform compatible. The example below shows usage in Python, but the same approach can be applied in other languages (Java, C#, C++, etc.) as well.

from huggingface_hub import hf_hub_download
import onnxruntime as ort
from transformers import AutoTokenizer

# Download the ONNX model; its weights live in a separate model.onnx_data file,
# which must be downloaded into the same folder as model.onnx
model_id = "onnx-community/embeddinggemma-300m-ONNX"
model_path = hf_hub_download(model_id, subfolder="onnx", filename="model.onnx")
hf_hub_download(model_id, subfolder="onnx", filename="model.onnx_data")
session = ort.InferenceSession(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Prepend the task-specific prompts to the query and the documents
prefixes = {
    "query": "task: search result | query: ",
    "document": "title: none | text: ",
}
query = prefixes["query"] + "Which planet is known as the Red Planet?"
documents = [
    "Venus is often called Earth's twin because of its similar size and proximity.",
    "Mars, known for its reddish appearance, is often referred to as the Red Planet.",
    "Jupiter, the largest planet in our solar system, has a prominent red spot.",
    "Saturn, famous for its rings, is sometimes mistaken for the Red Planet."
]
documents = [prefixes["document"] + x for x in documents]

inputs = tokenizer([query] + documents, padding=True, return_tensors="np")

# The second output holds the pooled sentence embeddings; the first (unused here) holds per-token states
_, sentence_embedding = session.run(None, inputs.data)
print(sentence_embedding.shape)
# (5, 768)

# Score the query (row 0) against the documents (rows 1-4)
query_embeddings = sentence_embedding[0]
document_embeddings = sentence_embedding[1:]
similarities = query_embeddings @ document_embeddings.T
print(similarities)

# Rank the documents from most to least similar
ranking = similarities.argsort()[::-1]
print(ranking)



Finetuning

As with all models compatible with the Sentence Transformers library, EmbeddingGemma can be easily finetuned on your specific dataset. To showcase this, we’ll finetune google/embeddinggemma-300m on the Medical Instruction and RetrIeval Dataset (MIRIAD), so that the finetuned model becomes particularly adept at finding passages of up to 1000 tokens from scientific medical papers given detailed medical questions. These passages can then serve as crucial context for a generative model to answer questions more effectively.

Below, we go through each key component of the finetuning process, showing the relevant code followed by an explanation.

Model
from sentence_transformers import SentenceTransformer, SentenceTransformerModelCardData

# Load the base model, with optional metadata for the automatically generated model card
model = SentenceTransformer(
    "google/embeddinggemma-300m",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="EmbeddingGemma-300m trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
    ),
)

This code loads the EmbeddingGemma model from Hugging Face, with optional model card metadata for documentation and sharing. The SentenceTransformer class loads the model weights and configuration, while the model_card_data argument attaches metadata useful for inclusion in the automatically generated model card.

Dataset
from datasets import load_dataset

# Load the train/eval/test splits and keep subsets for faster experimentation
train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))

This code loads the MIRIAD dataset, or rather, a copy that has been divided into train, eval, and test splits. Using a large, high-quality dataset ensures the model learns meaningful representations, while subsetting allows for faster experimentation. The load_dataset function fetches the dataset from Hugging Face Datasets, and the .select() method limits the number of samples for each split.

Loss Function
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)

This code defines the loss function for training, using Cached Multiple Negatives Ranking Loss (CMNRL). CMNRL is effective for retrieval tasks, as it uses in-batch negatives to efficiently train the model to distinguish between correct and incorrect pairs. The loss takes question-answer pairs and treats other answers in the batch as negatives, maximizing the distance between unrelated pairs in the embedding space. The mini_batch_size parameter controls the memory usage, but does not affect the training dynamics.
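The core of this loss can be sketched in plain Python: each question is scored against every passage in the batch, and a cross-entropy pushes the score of its own passage above the others. This sketch assumes cosine similarity multiplied by a scale of 20, and it leaves out the caching, which changes how gradients are computed in mini-batches but not the loss value. Function names are ours, not the library's.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def in_batch_negatives_loss(query_embs, passage_embs, scale=20.0):
    """Multiple negatives ranking loss for one batch.

    Query i's positive is passage i; every other passage in the batch acts
    as a negative. Returns the mean cross-entropy over the scaled score rows.
    """
    losses = []
    for i, q in enumerate(query_embs):
        logits = [scale * cosine(q, p) for p in passage_embs]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        losses.append(log_sum_exp - logits[i])
    return sum(losses) / len(losses)


queries = [[1.0, 0.0], [0.0, 1.0]]
matched = in_batch_negatives_loss(queries, [[0.9, 0.1], [0.1, 0.9]])
mismatched = in_batch_negatives_loss(queries, [[0.1, 0.9], [0.9, 0.1]])
print(matched < mismatched)  # True
```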

It’s recommended to use this loss with a large per_device_train_batch_size in SentenceTransformerTrainingArguments and a low mini_batch_size in CachedMultipleNegativesRankingLoss for a strong training signal with low memory usage. Additionally, the NO_DUPLICATES batch sampler is recommended to avoid accidental false negatives.
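Why do duplicates matter? If the same passage appears twice in a batch, one copy serves as a negative for a question it actually answers. The idea behind a no-duplicates sampler can be sketched as follows. This is a simplified illustration, not the Sentence Transformers implementation: conflicting samples are deferred to a later batch instead of being dropped.

```python
def no_duplicate_batches(samples, batch_size):
    """Group sample indices into batches in which no text appears twice.

    samples: list of tuples of texts, e.g. (question, passage).
    """
    remaining = list(range(len(samples)))
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for idx in remaining:
            texts = set(samples[idx])
            if len(batch) < batch_size and not (texts & seen):
                batch.append(idx)
                seen |= texts
            else:
                deferred.append(idx)  # retry this sample in a later batch
        batches.append(batch)
        remaining = deferred
    return batches


samples = [("q1", "p1"), ("q2", "p1"), ("q3", "p3")]
print(no_duplicate_batches(samples, batch_size=2))  # [[0, 2], [1]]
```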

Training Arguments
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers import SentenceTransformerTrainingArguments

run_name = "embeddinggemma-300m-medical-100k"
args = SentenceTransformerTrainingArguments(
    output_dir=f"models/{run_name}",
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
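    fp16=True,  # Set to False if your GPU can't handle FP16
    bf16=False,  # Set to True if your GPU supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # Avoid duplicate texts within a batch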
    prompts={
        "question": model.prompts["query"],
        "passage_text": model.prompts["document"],
    },
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,
)

This code sets up all hyperparameters and configuration for training, evaluation, and logging. Proper training arguments are crucial for efficient, stable, and reproducible training. The arguments control batch sizes, learning rate, mixed precision, evaluation and saving frequency, and more. Notably, the prompts dictionary maps dataset columns to prompts used by the model to distinguish queries from documents.
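To get a feel for what these arguments mean in practice, here is the approximate schedule for one epoch. The arithmetic assumes a single device, no gradient accumulation, and warmup steps computed by rounding warmup_ratio times the total number of steps up; the NO_DUPLICATES sampler may also change the exact batch count. The helper is hypothetical.

```python
import math


def training_schedule(num_samples, batch_size, warmup_ratio, eval_steps, num_devices=1):
    """Approximate (total steps, warmup steps, evaluations) for one epoch."""
    steps = math.ceil(num_samples / (batch_size * num_devices))
    warmup = math.ceil(steps * warmup_ratio)
    evaluations = steps // eval_steps
    return steps, warmup, evaluations


print(training_schedule(100_000, 128, 0.1, 100))  # (782, 79, 7)
```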

Evaluator
from sentence_transformers.evaluation import InformationRetrievalEvaluator

queries = dict(enumerate(eval_dataset["question"]))
corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-eval-1kq-31kd",
    show_progress_bar=True,
)
dev_evaluator(model)

This code sets up an evaluator for information retrieval, using queries and a corpus to measure model performance. Evaluation during training helps monitor progress and avoid overfitting. The evaluator computes retrieval metrics (NDCG, MRR, Recall, Precision, MAP, etc.) by checking if the model retrieves the correct passages for each query. It can be run before, during, and after training, and the results will be logged and incorporated in the automatically generated model card.

Note that this snippet evaluates all 1k evaluation questions against a corpus of all 1k evaluation passages plus 30k training passages, for a total of 31k documents. Evaluating against the evaluation passages alone would make the task too easy for the model.
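Two of the logged metrics are easy to reproduce by hand. The functions below are simplified versions of Recall@k and MRR@k, written for illustration rather than copied from the evaluator's code. As in the evaluator setup, query i's only relevant document is document i, and the rankings are hypothetical.

```python
def recall_at_k(ranked_ids, relevant_ids, k):
    """Fraction of relevant documents that appear in the top-k results."""
    hits = sum(1 for doc_id in ranked_ids[:k] if doc_id in relevant_ids)
    return hits / len(relevant_ids)


def reciprocal_rank_at_k(ranked_ids, relevant_ids, k):
    """1 / rank of the first relevant document within the top k, else 0."""
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0


# Hypothetical rankings for two queries.
rankings = {0: [5, 0, 7], 1: [3, 8, 9]}
relevant_docs = {0: {0}, 1: {1}}
mrr = sum(reciprocal_rank_at_k(rankings[q], relevant_docs[q], k=10) for q in rankings) / len(rankings)
recall = sum(recall_at_k(rankings[q], relevant_docs[q], k=10) for q in rankings) / len(rankings)
print(mrr, recall)  # 0.25 0.5
```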

Trainer
from sentence_transformers import SentenceTransformerTrainer

trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()

This code initializes and runs the training loop, coordinating all components.



Full Finetuning Script

Below is the complete script, combining all components above:

import logging
import traceback

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerModelCardData,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers


logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# Load the base model, with metadata for the automatically generated model card
model = SentenceTransformer(
    "google/embeddinggemma-300m",
    model_card_data=SentenceTransformerModelCardData(
        language="en",
        license="apache-2.0",
        model_name="EmbeddingGemma-300m trained on the Medical Instruction and RetrIeval Dataset (MIRIAD)",
    ),
)


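# Load the MIRIAD splits and keep subsets for faster experimentation
train_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="train").select(range(100_000))
eval_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="eval").select(range(1_000))
test_dataset = load_dataset("tomaarsen/miriad-4.4M-split", split="test").select(range(1_000))

# Define the loss; mini_batch_size only affects memory usage, not the training dynamics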
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=8)


run_name = "embeddinggemma-300m-medical-100k"
args = SentenceTransformerTrainingArguments(
    
    output_dir=f"models/{run_name}",
    
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    fp16=True,  
    bf16=False,  
    batch_sampler=BatchSamplers.NO_DUPLICATES,  
    prompts={  
        "question": model.prompts["query"],
        "passage_text": model.prompts["document"],
    },
    
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=20,
    run_name=run_name,  
)


queries = dict(enumerate(eval_dataset["question"]))
corpus = dict(enumerate(eval_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
dev_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-eval-1kq-31kd",  
    show_progress_bar=True,
)
dev_evaluator(model)


trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=loss,
    evaluator=dev_evaluator,
)
trainer.train()



dev_evaluator(model)

queries = dict(enumerate(test_dataset["question"]))
corpus = dict(enumerate(test_dataset["passage_text"] + train_dataset["passage_text"][:30_000]))
relevant_docs = {idx: [idx] for idx in queries}
test_evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="miriad-test-1kq-31kd",  
    show_progress_bar=True,
)
test_evaluator(model)


final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)



try:
    model.push_to_hub(run_name)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{run_name}')`."
    )
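The script sets batch_sampler=BatchSamplers.NO_DUPLICATES because of how in-batch negatives work. If a passage (or question) appears twice in one batch, one copy becomes a false negative for the other pair. The sketch below shows one simple greedy way to form batches without repeated texts. It is a hypothetical illustration of the idea, not the sampler that Sentence Transformers ships, and it does not reproduce that sampler's exact algorithm (for example, its shuffling).

```python
def no_duplicate_batches(rows, batch_size):
    """Greedily group row indices so that no text appears twice within a batch.

    rows: list of (question, passage) tuples. A row that would repeat a text already
    in the current batch is deferred to a later batch rather than dropped.
    """
    remaining = list(range(len(rows)))
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for idx in remaining:
            texts = set(rows[idx])
            if len(batch) < batch_size and not (texts & seen):
                batch.append(idx)
                seen |= texts
            else:
                deferred.append(idx)
        batches.append(batch)
        remaining = deferred
    return batches


# Hypothetical toy pairs: rows 0 and 1 share the same passage
rows = [("q1", "p1"), ("q2", "p1"), ("q3", "p3"), ("q4", "p4")]
print(no_duplicate_batches(rows, batch_size=2))  # [[0, 2], [1, 3]]
```

Every row still ends up in exactly one batch. The shared passage p1 simply lands in two different batches, so neither pair sees it as a negative.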



Training

We ran the full training script on an RTX 3090 with 24GB of VRAM, and training plus evaluation took 5.5 hours in total. If needed, you can reduce the memory footprint further by lowering mini_batch_size on the CachedMultipleNegativesRankingLoss and batch_size on the InformationRetrievalEvaluator instances. Below are the logs from our training run. The rows with an epoch and step of -1 are evaluations outside the training loop: the first scores the base model, and the last scores the finetuned model.
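Lowering mini_batch_size saves memory without shrinking the pool of in-batch negatives. The cached loss embeds the full batch of 128 in smaller chunks and then computes the loss over all of the resulting embeddings together. The sketch below shows only that chunking step, with a hypothetical toy embedding function standing in for the model. The real loss also caches gradients so it can backpropagate chunk by chunk, and that part is omitted here.

```python
def embed_in_mini_batches(texts, embed_fn, mini_batch_size):
    """Embed texts one chunk at a time and concatenate the results in order."""
    embeddings = []
    for start in range(0, len(texts), mini_batch_size):
        embeddings.extend(embed_fn(texts[start:start + mini_batch_size]))
    return embeddings


def toy_embed(chunk):
    """Hypothetical stand-in for the model: a deterministic 2-d vector per text."""
    return [[float(len(text)), float(sum(map(ord, text)) % 7)] for text in chunk]


batch = [f"passage {i}" for i in range(128)]  # one batch at per_device_train_batch_size=128
full = embed_in_mini_batches(batch, toy_embed, mini_batch_size=128)
chunked = embed_in_mini_batches(batch, toy_embed, mini_batch_size=8)
print(full == chunked, len(chunked))  # True 128
```

The 128 embeddings are identical either way, so the loss still contrasts each question against the whole batch. Only speed and peak memory change. The evaluator's batch_size behaves similarly: it controls how many texts are encoded at once, not which documents are compared.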

Eval NDCG@10 and Test NDCG@10 are the logged metrics miriad-eval-1kq-31kd_cosine_ndcg@10 and miriad-test-1kq-31kd_cosine_ndcg@10. A dash means that no value was logged at that step.

Epoch   Step  Training Loss  Validation Loss  Eval NDCG@10  Test NDCG@10
-1      -1    -              -                0.8474        0.8340
0.0256  20    0.1019         -                -             -
0.0512  40    0.0444         -                -             -
0.0767  60    0.0408         -                -             -
0.1023  80    0.0462         -                -             -
0.1279  100   0.0542         0.0525           0.8616        -
0.1535  120   0.0454         -                -             -
0.1790  140   0.0403         -                -             -
0.2046  160   0.0463         -                -             -
0.2302  180   0.0508         -                -             -
0.2558  200   0.0497         0.0449           0.8643        -
0.2813  220   0.0451         -                -             -
0.3069  240   0.0445         -                -             -
0.3325  260   0.0489         -                -             -
0.3581  280   0.0452         -                -             -
0.3836  300   0.0461         0.0406           0.8832        -
0.4092  320   0.0415         -                -             -
0.4348  340   0.04           -                -             -
0.4604  360   0.0399         -                -             -
0.4859  380   0.0423         -                -             -
0.5115  400   0.0352         0.0316           0.8823        -
0.5371  420   0.0408         -                -             -
0.5627  440   0.0356         -                -             -
0.5882  460   0.0371         -                -             -
0.6138  480   0.0276         -                -             -
0.6394  500   0.028          0.0280           0.8807        -
0.6650  520   0.0302         -                -             -
0.6905  540   0.0345         -                -             -
0.7161  560   0.0325         -                -             -
0.7417  580   0.033          -                -             -
0.7673  600   0.0314         0.0264           0.8910        -
0.7928  620   0.033          -                -             -
0.8184  640   0.029          -                -             -
0.8440  660   0.0396         -                -             -
0.8696  680   0.0266         -                -             -
0.8951  700   0.0262         0.0240           0.8968        -
0.9207  720   0.0262         -                -             -
0.9463  740   0.0327         -                -             -
0.9719  760   0.0293         -                -             -
0.9974  780   0.0304         -                -             -
-1      -1    -              -                0.9026        0.8862



Finetuned Evaluation

The base model already performed strongly, with an NDCG@10 of 0.8340 on our MIRIAD test set. Even so, finetuning on this domain-specific dataset improved it considerably.

Our finetuning run raised test NDCG@10 by +0.0522, from 0.8340 to 0.8862. The result is a model that comfortably outperforms any existing general-purpose embedding model of this size on our specific task. Investing more time and compute, for example in hard negatives mining or in training on more than 100k pairs, could yield even stronger results.
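To put the +0.0522 gain in perspective, the arithmetic below recomputes it from the logged test scores and expresses it in two additional ways: relative to the base score, and as a share of the headroom left below a perfect NDCG@10 of 1.0. The headroom framing is our own illustration, not a metric the evaluator reports.

```python
base_test = 0.8340   # test NDCG@10 of the base model (first row of the log)
final_test = 0.8862  # test NDCG@10 after finetuning (last row of the log)

abs_gain = final_test - base_test
rel_gain = abs_gain / base_test                  # improvement relative to the base score
headroom_closed = abs_gain / (1.0 - base_test)   # share of the gap to a perfect 1.0 that was closed

print(f"{abs_gain:+.4f} NDCG@10 | {rel_gain:.1%} relative | {headroom_closed:.1%} of headroom closed")
```

This works out to a gain of +0.0522, roughly 6.3% relative to the base model. That gain closes about 31% of the remaining gap to a perfect score.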


