Fine-tuning Gemma for Function Calling

By Anis Marrouchi & AI Bot

Welcome to this comprehensive guide on fine-tuning the Gemma model for function calling. This tutorial will walk you through the entire process, from setting up your environment to training and testing the model. We'll be using Torch XLA and Hugging Face's Transformer Reinforcement Learning (TRL) framework to achieve this.

Introduction to Gemma

Gemma is a family of lightweight, state-of-the-art open models from Google, designed for a variety of text generation tasks. These models are text-to-text, decoder-only large language models, available in English, and are well-suited for environments with limited resources.

Setting Up the Environment

Selecting the Runtime Environment

You can choose either Google Colab or Kaggle as your platform. For this guide, we'll focus on Kaggle.

Kaggle Setup

  1. Click Open in Kaggle.
  2. Click on Settings in the right sidebar.
  3. Under Accelerator, select TPUs.
  4. Save the settings, and the notebook will restart with TPU support.

Accessing Gemma via Hugging Face

  1. Create a Hugging Face Account: Sign up for a free account if you don't have one.
  2. Access the Gemma Model: Visit the Gemma model page and accept the usage conditions.
  3. Generate a Hugging Face Token: Go to your Hugging Face settings page and generate a new access token.

Configure Your Credentials

To access private models and datasets, you need to log in to the Hugging Face ecosystem.

Kaggle

  1. Open your Kaggle notebook and locate the Addons menu.
  2. Click on Secrets to manage your environment secrets.
  3. Add Hugging Face Token:
    • Click on the Add secret button.
    • In the Label field, enter HF_TOKEN.
    • In the Value field, paste your Hugging Face token.
    • Click Save to add the secret.
You can then read the secret inside the notebook and expose it as the HF_TOKEN environment variable, which the Hugging Face libraries pick up automatically:

import os
from kaggle_secrets import UserSecretsClient

# Read the Kaggle secret and expose it to the Hugging Face libraries
user_secrets = UserSecretsClient()
os.environ['HF_TOKEN'] = user_secrets.get_secret("HF_TOKEN")
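
If you want to confirm the token is being picked up, one optional sanity check is to query the Hugging Face Hub with the huggingface_hub client (installed alongside transformers); this snippet is only for verification and isn't required for the rest of the tutorial:

from huggingface_hub import whoami
# Should print your Hugging Face username if HF_TOKEN is valid
print(whoami()["name"])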

Install Dependencies

Set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.

!pip install transformers==4.46.1 datasets==3.1.0 trl==0.12.0 peft==0.13.2 accelerate==0.34.0 torch~=2.5.0 torch_xla[tpu]~=2.5.0 tpu-info
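
After the install, you can optionally confirm that the TPU runtime is visible to Torch XLA (the tpu-info CLI installed above also reports chip status). The exact device string may vary by torch_xla version, but a quick check looks roughly like this:

import torch_xla.core.xla_model as xm
# Should report an XLA device such as xla:0 when the TPU runtime is active
print(xm.xla_device())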

Fine-tuning Gemma 2 for Function Calling

Initializing Gemma 2 Model

Initialize AutoModelForCausalLM from the transformers library by loading the pre-trained Gemma 2 model from Hugging Face, and load the matching tokenizer.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.config.use_cache = False  # disable the KV cache during training
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "right"
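
The chat template defined below relies on Gemma's turn markers. If you want to confirm they exist in the tokenizer's vocabulary before writing the template, a quick check might look like this:

print(tokenizer.bos_token)
print(tokenizer.convert_tokens_to_ids("<start_of_turn>"))
print(tokenizer.convert_tokens_to_ids("<end_of_turn>"))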

Load a Dataset

Use an existing dataset from Hugging Face, such as lilacai/glaive-function-calling-v2-sharegpt.

from datasets import load_dataset
dataset = load_dataset("lilacai/glaive-function-calling-v2-sharegpt", split="train[:15%]")
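
Each example stores a ShareGPT-style conversation, i.e. a list of messages with "from" and "value" keys, which is why the custom chat template below indexes messages by those keys. Inspecting the first example makes the structure clear:

# Each conversation is a list of {"from": ..., "value": ...} messages
print(dataset[0]["conversations"][:2])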

Create a Custom Chat Template

Define a chat template that converts each conversation into a single string the tokenizer can process. Because Gemma's format has no dedicated system role, the template folds the system message into the first user turn.

chat_template = "{{ bos_token }}{% if messages[0]['from'] == 'system' %}{{'<start_of_turn>user\n' + messages[0]['value'] | trim + ' ' + messages[1]['value'] | trim + '<end_of_turn>\n'}}{% set messages = messages[2:] %}{% endif %}{% for message in messages %}{% if message['from'] == 'human' %}{{'<start_of_turn>user\n' + message['value'] | trim + '<end_of_turn>\n'}}{% elif message['from'] == 'gpt' %}{{'<start_of_turn>model\n' + message['value'] | trim + '<end_of_turn>\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
tokenizer.chat_template = chat_template
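
To see what the template produces, you can render a small hand-written conversation (the content here is purely illustrative); note how the system message is merged into the first user turn:

sample = [
    {"from": "system", "value": "You are a helpful assistant with access to functions."},
    {"from": "human", "value": "What's the weather in Paris?"},
    {"from": "gpt", "value": "Let me check that for you."},
]
print(tokenizer.apply_chat_template(sample, tokenize=False, add_generation_prompt=False))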

Define the Formatting Function

Apply the template to each row in the dataset.

def formatting_prompts_func(examples):
    convos = examples["conversations"]
    texts = [tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False) for convo in convos]
    return { "text" : texts, }
dataset = dataset.map(formatting_prompts_func, batched=True)

LoRA Configuration

Set up the LoRA (Low-Rank Adaptation) configuration. Instead of updating all model weights, LoRA trains small low-rank adapter matrices attached to the attention and MLP projection layers listed in target_modules.

from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=16,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
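
If you're curious how few weights LoRA actually trains with r=16 on these projections, you can wrap a throwaway copy of the base model and print the count. This is purely for inspection; the SFTTrainer below applies peft_config itself, so the wrapped copy isn't reused:

from peft import get_peft_model
probe = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
# Prints trainable vs. total parameters; the adapters are a small fraction of the full model
get_peft_model(probe, peft_config).print_trainable_parameters()
del probe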

Set Training Configuration

Define training arguments for the model.

from trl import SFTTrainer, SFTConfig

training_arguments = SFTConfig(
    output_dir="./results",
    max_steps=100,
    per_device_train_batch_size=32,
    optim="adafactor",
    learning_rate=0.0002,
    bf16=True,
    max_seq_length=1024,
    dataset_text_field="text",
    packing=True,
    logging_steps=1,
    seed=42,
)

Train the Model

Use Hugging Face TRL's SFTTrainer class to train the model.

trainer = SFTTrainer(model=model, train_dataset=dataset, peft_config=peft_config, args=training_arguments)
trainer.train()

Save the Fine-tuned Model

After training, move the model to the CPU and save the fine-tuned LoRA adapter.

trainer.model.to('cpu').save_pretrained("gemma-func-ft")
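
It can also be convenient to save the tokenizer, including its custom chat template, alongside the adapter so that everything needed for inference lives in one directory:

tokenizer.save_pretrained("gemma-func-ft")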

Prompt Using the Newly Fine-tuned Model

Reload the base model, merge the LoRA adapter into it, and test the result with a sample prompt.

from peft import PeftModel

# Reload the base model and merge the fine-tuned LoRA adapter into it
base_model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "gemma-func-ft")
model = model.merge_and_unload()

input_text = "<start_of_turn>user\nYou are a helpful assistant..."
input_ids = tokenizer(input_text, return_tensors="pt").to("cpu")
outputs = model.generate(**input_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
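
Instead of writing the control tokens by hand, you can also build the prompt through the chat template defined earlier; with add_generation_prompt=True it appends the final <start_of_turn>model marker for you. The message below is purely illustrative:

messages = [{"from": "human", "value": "What's the current time in Tokyo?"}]
prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(input_ids=prompt_ids, max_new_tokens=256)
print(tokenizer.decode(outputs[0]))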

Conclusion

Congratulations! You've successfully fine-tuned Gemma for Function Calling using Torch XLA and PEFT with LoRA on TPUs. Explore further by experimenting with different datasets and tuning hyperparameters.


Reference: Google Gemma Cookbook by Google LLC.

