Mastering MongoDB Queries with OpenAI: A Guide to Generating Queries for Desired Outputs


I’m a big fan of mongoplayground.net as it is one of the most useful and widely used tools for swiftly testing various MongoDB queries while working on a problem. Although this tool is really handy and convenient to use, I have always wondered if there is a way to automatically generate the query for a given input and an expected output, rather than having to manually construct these queries.

Thanks to Generative AI tools, this is now possible!

However, simply generating queries isn’t sufficient. For this tool to be really useful, the generated query must be accurate and produce the expected output. Therefore, we’ll explore using function calling to verify that the generated query is correct.

In this article, we will cover the following:

  • Auto-generating MongoDB queries with OpenAI’s Chat Completion API
  • Building an AI agent that generates and verifies MongoDB queries using function calling

Let’s dive in and generate some MongoDB queries! See the “Code” section at the end of the article for the GitHub links to the Jupyter notebooks for the examples discussed below.

Auto-generate MongoDB queries

In this use case, we will take two JSON arrays as input from the user – input_data and expected_output. The input_data array contains a few example instances of the data, and expected_output contains the expected query output. The model’s task is to generate a MongoDB query that, when applied to this input_data, produces the given expected_output.

We will use OpenAI’s Chat Completion API with either the gpt-3.5-turbo-1106 or the gpt-4-1106-preview model for this use case.

Install and import the required packages

!pip install openai
!pip install python-dotenv

from openai import OpenAI
import os
from dotenv import load_dotenv

Place the OpenAI API key in a .env file

The OpenAI API key needs to be placed in a .env file, so that it is available as an environment variable.

OPENAI_API_KEY="sk-TZP9XNL34234HNKJiohsdl"
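
As a quick sanity check (a minimal sketch, assuming the .env file sits in the working directory), you can confirm the key is visible to the process before instantiating the client:

from dotenv import load_dotenv
import os

load_dotenv()
assert os.environ.get("OPENAI_API_KEY"), "OPENAI_API_KEY not found in the environment"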

Instantiate the OpenAI client

_ = load_dotenv()
openai_client = OpenAI()

GPT3_MODEL = "gpt-3.5-turbo-1106"
GPT4_MODEL = "gpt-4-1106-preview"

System and User Prompt to auto-generate MongoDB queries

The system prompt is self-explanatory. In the user prompt, there are a few key things to note.

  • First, we instruct the model to always produce a MongoDB aggregation pipeline.
  • Second, we ask it to produce a JSON response. Since we are using the latest GPT models, which support JSON mode, we can rely on getting valid JSON back. We define the fields that we need in the response.
    • mongoDBQuery – Here, we explicitly specify that this field should contain just the array of pipeline stages, with no "db.collection.aggregate" prefix.
    • queryExplanation – This is a text summary explaining the generated query.
  • Third, we pass in the input_data and the expected_output provided by the user.

def get_system_prompt():
    return f"""You are a MongoDB expert with great expertise in \
    writing MongoDB queries for any given data \
    to produce an expected output.
    """

def get_user_prompt(input_data, output_data):
    return f"""Your task is to write a MongoDB Query, \
    specifically an aggregation pipeline that would \
    produce the expected output for the given input.

    You will always return a JSON response with the following fields.
    ```
    mongoDBQuery: The MongoDB aggregation pipeline to produce \
    the expected output for a given input. This field corresponds \
    to just the list of stages in the aggregation pipeline \
    and shouldn't contain the "db.collection.aggregate" prefix.
    
    queryExplanation: A detailed explanation for the query \
    that was returned.
    ```
    
    Input data: {input_data} 
    Expected output data: {output_data}
    """

Define a function to invoke the Chat Completion API

We pass in response_format={"type": "json_object"} to the API in order to get a valid JSON response from the model.

def get_mongodb_query(input_data, output_data, model=GPT3_MODEL):
    system_prompt = get_system_prompt()
    user_prompt = get_user_prompt(input_data, output_data)
    
    messages = []
    messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    
    chat_completion = openai_client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=0,
        response_format={"type": "json_object"}
    )

    print(f"Assistant Response:\n{chat_completion.choices[0].message.content}")

Example data

Here is a simple example for testing. The input data is an array of player objects, each containing the name of the player and the team they belong to. The expected output groups the players by team and counts the number of players in each team.

ex1_input_data = """
[
  {
    "name": "Sachin",
    "team": "India"
  },
  {
    "name": "Sourav",
    "team": "India"
  },
  {
    "name": "Lara",
    "team": "West Indies"
  }
]
"""

ex1_output_data = """
[
 {
   "team": India,
   "playerCount": 2
 },
 {
   "team": "West Indies",
   "playerCount": 1
 }
]
"""

Invoke the Chat Completion API to generate the MongoDB Query

get_mongodb_query(ex1_input_data, ex1_output_data, GPT4_MODEL)

Below is the response from the AI assistant.

Assistant Response:

    {
        "mongoDBQuery": [
            {
                "$group": {
                    "_id": "$team",
                    "playerCount": { "$sum": 1 }
                }
            },
            {
                "$project": {
                    "_id": 0,
                    "team": "$_id",
                    "playerCount": 1
                }
            }
        ],
        "queryExplanation": "The query consists of two stages in the aggregation pipeline. The first stage is the `$group` stage, which groups the documents by the 'team' field. For each group, it counts the number of players using the `$sum` accumulator, which increments by 1 for each document in the group, resulting in a count of players per team. The second stage is the `$project` stage, which reshapes each document in the stream; the `_id` field is suppressed and the 'team' field is set to the value of `_id` from the `$group` stage. The 'playerCount' field is included as is."
    }
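
Because we requested JSON mode, the message content is guaranteed to be valid JSON. As a minimal sketch – assuming get_mongodb_query were changed to return chat_completion.choices[0].message.content instead of just printing it – downstream code could parse it directly:

import json

#Hypothetical: assumes get_mongodb_query is modified to return
#chat_completion.choices[0].message.content instead of printing it.
content = get_mongodb_query(ex1_input_data, ex1_output_data, GPT4_MODEL)
response = json.loads(content)
pipeline = response["mongoDBQuery"]          #list of aggregation stage dicts
explanation = response["queryExplanation"]   #text summary of the query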

This is cool! But we still don’t know for sure whether this generated query will indeed produce the expected output. To be sure, we can help the model by providing a tool that executes the query against the input data in a real MongoDB instance and verifies the output against the expected output.

AI agent to generate and verify MongoDB Queries

So, we will take this a step further: we will create an AI agent and provide it with a custom tool using OpenAI’s function calling capability.

Specifically, we will pass in a “verify” function as the tool, which the model can invoke after it generates the query to test and verify the output produced by that query.

Basic Setup

Let’s do the basic setup as before and instantiate the OpenAI client.

!pip install python-dotenv
!pip install pymongo
!pip install termcolor 
!pip install openai
!pip install requests
!pip install json5

import json
import json5
import ast
from openai import OpenAI
import os
from termcolor import colored
from pymongo import MongoClient
from dotenv import load_dotenv
_ = load_dotenv()
openai_client = OpenAI()

GPT3_MODEL = "gpt-3.5-turbo-1106"
GPT4_MODEL = "gpt-4-1106-preview"

Utility function to print the chat conversation

def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "tool": "magenta"
    }
    
    for message in messages:
        if message["role"] == "system":
            print(colored(f"system: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "user":
            print(colored(f"user: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and message.get("tool_calls"):
            print(colored(f"assistant: {message['tool_calls'][0]['function']}\n", role_to_color[message["role"]]))
        elif message["role"] == "assistant" and not message.get("tool_calls"):
            print(colored(f"assistant: {message['content']}\n", role_to_color[message["role"]]))
        elif message["role"] == "tool":
            print(colored(f"function ({message['name']}): {message['content']}\n", role_to_color[message["role"]]))

Define a function to work with Chat Completion API

The function below interacts with the Chat Completion API and returns the response message along with the finish reason.

def chat_completion(messages, tools=None, tool_choice=None, model=GPT3_MODEL):
    chat_completion = openai_client.chat.completions.create(
        model=model,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        temperature=0,
        response_format={"type": "json_object"}
    )
    print(f"chat_completion: {chat_completion}")
    chat_completion_json = json.loads(chat_completion.model_dump_json())
    assistant_message = chat_completion_json["choices"][0]["message"]
    finish_reason = chat_completion_json["choices"][0]["finish_reason"]
    
    #This is done to avoid the error thrown by the Chat Completion API 
    #on subsequent calls to the API when using tools.
    if assistant_message.get("function_call") is None:
        assistant_message.pop("function_call", None)

    chat_response = {
        "assistant_message": assistant_message,
        "finish_reason": finish_reason
    }
    
    return chat_response

Define the function to be passed as a tool to the model

Next, we define the function that we want to pass as a tool to be used by the model as part of its processing and arriving at an output for the user query.

This function verifies the generated query: it executes the query against an actual MongoDB instance with the given input data, so that the actual output can be compared with the expected output to determine whether the generated query actually works.

tools = [
    {
        "type": "function",
        "function": {
            "name": "verify_query",
            "description": f"""This function is used to verify that the MongoDB query produces the expected output data for the given user input. 
                            Input to this function should be a fully formed MongoDB query.""",
            "parameters": {
                "type": "object",
                "properties": {
                    "inputData": {
                        "type": "string",
                        "description": "The sample input in JSON format given by the user",
                    },
                    "expectedOutput": {
                        "type": "string",
                        "description": "The expected output data in JSON format given by the user",
                    },
                    "mongoDBQuery": {
                        "type": "string",
                        "description": "MongoDB aggregation pipeline stages array that produces the expected output for the given input data. \
                        This field corresponds to just the list of stages in the aggregation pipeline \
                        and shouldn't contain the `db.collection.aggregate` prefix"
                    }
                },
                "required": ["inputData", "expectedOutput", "mongoDBQuery"]
            }
        }
    }
]
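
For reference, when the model chooses to call this tool, the assistant message carries the call in a tool_calls array shaped roughly as below (the id is a made-up placeholder; note that the arguments arrive as a JSON-encoded string, which is why the function-calling code later parses it with json.loads):

"tool_calls": [
    {
        "id": "call_abc123",
        "type": "function",
        "function": {
            "name": "verify_query",
            "arguments": "{\"inputData\": \"[...]\", \"expectedOutput\": \"[...]\", \"mongoDBQuery\": \"[...]\"}"
        }
    }
]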

Utility functions to connect and execute queries against a MongoDB instance

We will need to create a database in MongoDB so that we can connect to it and execute the generated query when the model decides to invoke the “verify_query” function as part of its processing.

Below are utility functions to connect to the DB, clean up any existing data in the collection, insert the user’s input data into the collection, and finally execute the generated query.

def get_database():
    CONNECTION_STRING = "mongodb+srv://..."
    DB_NAME="test"
    client = MongoClient(CONNECTION_STRING)
    return client[DB_NAME]

def insert_data(collection, data_array):
    if len(data_array) != 0 and collection is not None:
        collection.insert_many(data_array)

def delete_data(collection):
    if collection is not None:
        collection.delete_many({})

def execute_query(collection, aggregation_pipeline=None):
    if not aggregation_pipeline:
        raise ValueError("Aggregation pipeline must be provided.")
    return list(collection.aggregate(aggregation_pipeline))
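
As a quick smoke test (a sketch, assuming the connection string above points at a reachable cluster), these helpers can be exercised directly with the example data:

sample_docs = [
    {"name": "Sachin", "team": "India"},
    {"name": "Sourav", "team": "India"},
    {"name": "Lara", "team": "West Indies"},
]
db = get_database()
collection = db["mqg"]
delete_data(collection)
insert_data(collection, sample_docs)
pipeline = [{"$group": {"_id": "$team", "playerCount": {"$sum": 1}}}]
print(execute_query(collection, pipeline))
#Expected (order may vary):
#[{'_id': 'India', 'playerCount': 2}, {'_id': 'West Indies', 'playerCount': 1}]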

Function Calling

When the model decides to call the “verify_query” function and returns the function name and parameters in its response, we invoke the “execute_function_call” function, which extracts the function name and argument details and dispatches to the appropriate function – in this case, “verify_query”.

The “verify_query” function is simple. It inserts the input data into a MongoDB collection, executes the generated query, compares the output against the expected output, and returns a Boolean.

The “execute_function_call” function then returns a success, failure, or error message depending on the response from the “verify_query” function.
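
The comparison in “verify_query” is order-insensitive: both lists are sorted on a canonical JSON string of each document before comparing. A tiny illustration (using the json module imported earlier):

expected = [{"team": "India", "playerCount": 2}, {"team": "West Indies", "playerCount": 1}]
actual = [{"team": "West Indies", "playerCount": 1}, {"team": "India", "playerCount": 2}]

canonical = lambda doc: json.dumps(doc, sort_keys=True)
assert sorted(expected, key=canonical) == sorted(actual, key=canonical)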

def verify_query(input_data, expected_output, query): 
    db = get_database()
    collection = db["mqg"]
    
    #Remove any data present in the collection before inserting new data
    delete_data(collection)
    insert_data(collection, input_data)

    result_from_db = execute_query(collection, query)
    print(f"Result: {result_from_db}")

    # Sort the lists based on a canonical representation of each dictionary
    sorted_list1 = sorted(expected_output, key=lambda x: json.dumps(x, sort_keys=True))
    sorted_list2 = sorted(result_from_db, key=lambda x: json.dumps(x, sort_keys=True))
    
    return sorted_list1 == sorted_list2

def execute_function_call(message):
    success = { "result": "success", "message": "This query produces the expected output" }
    failure = { "result": "failure", "message": "This query doesn't produce the expected output" }
    error = { "result": "error" }
    
    if message["tool_calls"][0]["function"]["name"] == "verify_query":
        arguments = json.loads(message["tool_calls"][0]["function"]["arguments"])
        print(f"Assistant Generated Function Arguments: {arguments}")
        
        #Parse the string arguments into Python lists.
        #Note: ast.literal_eval works here because the example JSON contains no
        #true/false/null literals; json5.loads would be more robust otherwise.
        input_data = ast.literal_eval(arguments["inputData"])
        expected_output = ast.literal_eval(arguments["expectedOutput"])
        query = json5.loads(arguments["mongoDBQuery"])
        
        result_bool = verify_query(input_data, expected_output, query)
        print(f"Results Match: {result_bool}")
            
        return success if result_bool else failure
    else:
        error.update({"message": f"Error: function {message['tool_calls'][0]['function']['name']} does not exist"})
        return error

System and User Prompt for the AI agent

The system prompt is the same one we used earlier. We slightly adjust the user prompt to explicitly instruct the model to verify the generated query before producing the final response. Note that the user prompt interpolates input_data and output_data, which we define first in the sketch below.
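
A minimal sketch defining them with the same example data as before (an assumption – the notebook needs input_data and output_data defined before the prompt strings are built):

input_data = """[{"name": "Sachin", "team": "India"}, {"name": "Sourav", "team": "India"}, {"name": "Lara", "team": "West Indies"}]"""
output_data = """[{"team": "India", "playerCount": 2}, {"team": "West Indies", "playerCount": 1}]"""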

system_prompt = f"""
You are a MongoDB expert with great expertise in writing MongoDB queries 
for any given data to produce an expected output.
"""

user_prompt = f"""
Below are your tasks to perform. Follow the instructions entirely without missing anything. 

1. Your task is to write a MongoDB Query, specifically an aggregation pipeline\
that would produce the expected output for the given input.

2. Verify that executing the generated query actually produces the output \
matching the given expected output.

3. The final response should always be returned in JSON with the following fields.

4. Do not make a call to the same tool twice in a single response. 
```
mongoDBQuery: The MongoDB aggregation pipeline to produce the expected output for a given input.\
This field corresponds to just the list of stages in the aggregation pipeline \
and shouldn't contain the "db.collection.aggregate" prefix.

queryExplanation: A detailed explanation for the query that was returned.
```

Input data: {input_data} 
Expected output data: {output_data}
"""

user_prompt_failure_scenario = f"""
The query that you provided as argument to the function call didn't produce the expected output. \
Try again to write a new query that would produce the expected output
"""

AI Agent execution

Here we create agent-like behavior by first invoking the model and then checking the output of the verify_query function to determine whether we need to call the model again with a custom failure prompt, or to generate the final result. We do this in a loop for a specified number of iterations instead of letting it run forever until it arrives at the correct query, though we can certainly change that behavior as we wish.

counter = 1
max_loop = 3

messages = []
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_prompt})

chat_response = chat_completion(
    messages=messages, 
    tools=tools, 
    tool_choice={"type": "function", "function": {"name": "verify_query"}}
)

assistant_message = chat_response["assistant_message"]
messages.append(assistant_message)

#pretty_print_conversation(messages)
#print(json.dumps(messages, indent=2))

while counter < max_loop:
    print(f"Counter: {counter}")
    counter += 1
    
    if assistant_message.get("tool_calls"):
        result = execute_function_call(assistant_message)
        print(f"Results after invoking the function: {result}")

        messages.append({
            "role": "tool", 
            "tool_call_id": assistant_message["tool_calls"][0]['id'], 
            "name": assistant_message["tool_calls"][0]["function"]["name"], 
            "content": result["message"]
        })

        if result["result"] == "failure":
            messages.append({"role": "user", "content": user_prompt_failure_scenario})
            chat_response = chat_completion(
                messages=messages, 
                tools=tools, 
                tool_choice={"type": "function", "function": {"name": "verify_query"}}
            )
        else:
            chat_response = chat_completion(messages, tools)
            
        assistant_message = chat_response["assistant_message"]
        messages.append(assistant_message)
        
        pretty_print_conversation(messages)
    else:
        break

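#The loop ended while the model was still requesting a tool call, so we
#fall back to the query from the last tool-call arguments.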
if assistant_message.get("tool_calls"):
    arguments = json.loads(assistant_message["tool_calls"][0]["function"]["arguments"])
    mongoDBQuery = json5.loads(arguments["mongoDBQuery"])
    return_val = {
        "mongoDBQuery": mongoDBQuery,
        "queryExplanation": ""
    }
    print(f"The best response from the assistant: {return_val}")
else:
    print(f"Final Response from the assistant: {assistant_message['content']}")

Response from the AI agent

Below is the response from the AI agent when we invoked it using the same input and expected output data that we used earlier.

{
    "mongoDBQuery": [
        {"$group": {"_id": "$team", "playerCount": {"$sum": 1}}},
        {"$project": {"team": "$_id", "_id": 0, "playerCount": 1}}
    ],
    "queryExplanation": "The query consists of two stages in the aggregation pipeline. The first stage is a $group stage, which groups the documents by the 'team' field. For each group, it calculates the count of players by summing 1 for each document in the group, resulting in a 'playerCount' field. The second stage is a $project stage, which reshapes each document in the stream; it includes the 'team' field (renamed from the grouping '_id' field) and the 'playerCount' field, while excluding the default '_id' field."
}

Code

The code for these examples is in separate Jupyter notebooks; the GitHub links are given below.

MongoDB query with OpenAI

MongoDB query with OpenAI function calling

Summary

To summarize, we were able to use the Chat Completion API along with its function calling capability to generate MongoDB queries and verify them as well.

We can easily create a UI around it to extend the example and make it more useful!

Thanks for reading and happy coding!
