Fine-tuning Gemma for Function Calling

Welcome to this comprehensive guide on fine-tuning the Gemma model for function calling. This tutorial will walk you through the entire process, from setting up your environment to training and testing the model. We'll be using Torch XLA and Hugging Face's Transformer Reinforcement Learning (TRL) framework to achieve this.
Introduction to Gemma
Gemma is a family of lightweight, state-of-the-art open models from Google, designed for a variety of text generation tasks. These models are text-to-text, decoder-only large language models, available in English, and are well-suited for environments with limited resources.
Setting Up the Environment
Selecting the Runtime Environment
You can choose either Google Colab or Kaggle as your platform. For this guide, we'll focus on Kaggle.
Kaggle Setup
- Click Open in Kaggle.
- Click on Settings in the right sidebar.
- Under Accelerator, select TPUs.
- Save the settings, and the notebook will restart with TPU support.
Accessing Gemma through Hugging Face
- Create a Hugging Face Account: Sign up for a free account if you don't have one.
- Access the Gemma Model: Visit the Gemma model page and accept the usage conditions.
- Generate a Hugging Face Token: Go to your Hugging Face settings page and generate a new access token.
Configure Your Credentials
To access private models and datasets, you need to log in to the Hugging Face ecosystem.
Kaggle
- Open your Kaggle notebook and locate the Addons menu.
- Click on Secrets to manage your environment secrets.
- Add your Hugging Face token:
  - Click the Add secret button.
  - In the Label field, enter HF_TOKEN.
  - In the Value field, paste your Hugging Face token.
  - Click Save to add the secret.
Then read the token inside the notebook and expose it as an environment variable:
import os
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
os.environ['HF_TOKEN'] = user_secrets.get_secret("HF_TOKEN")
Install Dependencies
Set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.
!pip install transformers==4.46.1 datasets==3.1.0 trl==0.12.0 peft==0.13.2 accelerate==0.34.0 torch~=2.5.0 torch_xla[tpu]~=2.5.0 tpu-info
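Before moving on, it's worth confirming that the notebook can actually see the TPU. A minimal check with Torch XLA (the exact device string, e.g. xla:0, may vary by runtime):
import torch_xla.core.xla_model as xm
# Should report an XLA device such as xla:0 when the TPU runtime is attached
print(xm.xla_device())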
Fine-tuning Gemma 2 for Function Calling
Initializing the Gemma 2 Model
Initialize the AutoModelForCausalLM class from the transformers library by loading a pre-trained Gemma 2 model from Hugging Face.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False  # disable the KV cache during training
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
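Optionally, a quick sanity check that the checkpoint loaded with the expected precision:
print(model.dtype)  # expected: torch.bfloat16
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")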
Load a Dataset
Use an existing dataset from Hugging Face, such as lilacai/glaive-function-calling-v2-sharegpt.
from datasets import load_dataset
dataset = load_dataset("lilacai/glaive-function-calling-v2-sharegpt", split="train[:15%]")
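The conversations in this dataset follow the ShareGPT convention: each turn is a dict with "from" and "value" keys, which is why the chat template below branches on the "from" field. Printing one record makes the structure visible:
print(dataset[0]["conversations"][0])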
Create a Custom Chat Template
Define a chat template to convert conversations into a single tokenizable string.
chat_template = "{{ bos_token }}{% if messages[0]['from'] == 'system' %}{{'<start_of_turn>user\n' + messages[0]['value'] | trim + ' ' + messages[1]['value'] | trim + '<end_of_turn>\n'}}{% set messages = messages[2:] %}{% endif %}{% for message in messages %}{% if message['from'] == 'human' %}{{'<start_of_turn>user\n' + message['value'] | trim + '<end_of_turn>\n'}}{% elif message['from'] == 'gpt' %}{{'<start_of_turn>model\n' + message['value'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
tokenizer.chat_template = chat_template
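To see what the template produces, you can render a small hand-written conversation (the messages below are illustrative, not taken from the dataset):
sample_convo = [
    {"from": "system", "value": "You are a helpful assistant with access to functions."},
    {"from": "human", "value": "What's the weather like in Paris?"},
]
# System content is merged into the first user turn, and the trailing
# '<start_of_turn>model' is added because add_generation_prompt=True.
print(tokenizer.apply_chat_template(sample_convo, tokenize=False, add_generation_prompt=True))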
Define the Formatting Function
Apply the template to each row in the dataset.
def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return {"text": texts}
dataset = dataset.map(formatting_prompts_func, batched=True)
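Printing the first formatted record is an easy way to confirm the Gemma turn markers were applied as expected:
print(dataset[0]["text"][:500])  # truncated for readability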
LoRA Configuration
Set up the LoRA (Low-Rank Adaptation) configuration so that only small adapter matrices attached to the attention and MLP projection layers are trained, while the base model weights stay frozen.
from peft import LoraConfig
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
Set Training Configuration
Define training arguments for the model.
from trl import SFTTrainer, SFTConfig
training_arguments = SFTConfig(
    output_dir="./results",
    max_steps=100,
    per_device_train_batch_size=32,
    optim="adafactor",
    learning_rate=0.0002,
    bf16=True,
    max_seq_length=1024,
    dataset_text_field="text",
    packing=True,
    logging_steps=1,
    seed=42,
)
Train the Model
Use Hugging Face TRL's SFTTrainer class to train the model.
trainer = SFTTrainer(model=model, train_dataset=dataset, peft_config=peft_config, args=training_arguments)
trainer.train()
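While training runs, you can check TPU utilization from another cell using the tpu-info CLI installed earlier (output format differs between versions):
!tpu-info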
Save the Fine-tuned Model
After training, move the model to the CPU and save the LoRA adapter weights.
trainer.model.to('cpu').save_pretrained("gemma-func-ft")
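It can also be convenient to save the tokenizer, including the custom chat template, alongside the adapter so the output directory is self-contained (optional):
tokenizer.save_pretrained("gemma-func-ft")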
Prompt Using the Newly Fine-tuned Model
Reload the fine-tuned model and test it with a sample prompt.
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "gemma-func-ft")
model = model.merge_and_unload()
input_text = "<start_of_turn>user\nYou are a helpful assistant..."
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
outputs = model.generate(**input_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
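Instead of hand-writing the turn markers, you can also build the prompt with the chat template defined earlier; the user message below is just an illustrative example:
messages = [{"from": "human", "value": "You are a helpful assistant with access to functions. What's the weather like in Paris?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# add_special_tokens=False avoids a second BOS, since the template already emits bos_token
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
outputs = model.generate(**inputs, max_new_tokens=512)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))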
Conclusion
Congratulations! You've successfully fine-tuned Gemma for Function Calling using Torch XLA and PEFT with LoRA on TPUs. Explore further by experimenting with different datasets and tuning hyperparameters.
Reference: Google Gemma Cookbook by Google LLC.