Skip to content

LangChain Integration

Using the API with LangChain

Many applications use the LangChain library. The example below shows how to integrate the SDSC LLM service with LangChain. To communicate with the LLM through its API while still using LangChain components (here, `PromptTemplate` for prompt formatting), we create a small custom wrapper class that sends the formatted prompt to the service.

Replace `<api_key>` with your actual API key.

import requests
from langchain_core.prompts import PromptTemplate

class CustomSDSCLLM:
    """Minimal client for the SDSC LLM service's chat-completions endpoint.

    Each call sends a single-turn conversation (one system message plus one
    user message) and returns the assistant's reply as plain text. This is a
    thin wrapper, not a LangChain ``LLM`` subclass. Callers format prompts
    with LangChain (e.g. ``PromptTemplate``) and pass the resulting string
    to :meth:`call`.
    """

    DEFAULT_BASE_URL = "https://sdsc-llm-openwebui.nrp-nautilus.io/api/chat/completions"
    DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

    def __init__(
        self,
        api_key: str,
        model: str,
        base_url: str = DEFAULT_BASE_URL,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
        timeout: float = 120.0,
    ):
        """Store connection settings.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            model: Model name placed in the request payload.
            base_url: Full URL of the chat-completions endpoint.
            system_prompt: Content of the system message sent before the
                user prompt.
            timeout: Seconds passed to ``requests.post`` as its timeout.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = base_url
        self.system_prompt = system_prompt
        self.timeout = timeout

    def _build_headers(self) -> dict:
        """Return the HTTP headers for an authenticated JSON request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, prompt: str) -> dict:
        """Return the JSON body: system message, then *prompt* as the user turn."""
        return {
            "model": self.model,
            "messages": [
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": prompt},
            ],
            # Ask for a single complete response rather than a stream.
            "stream": False,
        }

    @staticmethod
    def _extract_content(response_data: dict) -> str:
        """Return the stripped text of the first choice in a response body."""
        return response_data["choices"][0]["message"]["content"].strip()

    def call(self, prompt: str) -> str:
        """Send *prompt* to the model and return its reply.

        Args:
            prompt: Text to send as the user message.

        Returns:
            The first choice's message content, with surrounding whitespace
            stripped.

        Raises:
            RuntimeError: If the endpoint answers with a status other than 200.
            requests.exceptions.Timeout: If the server does not respond
                within ``self.timeout`` seconds.
        """
        response = requests.post(
            self.base_url,
            headers=self._build_headers(),
            json=self._build_payload(prompt),
            # Without a timeout, requests waits indefinitely on a stalled server.
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise RuntimeError(
                f"Request failed with status {response.status_code}: {response.text}"
            )
        return self._extract_content(response.json())

def test_langchain(question: str, model: str, api_key: str) -> str:
    """Ask *model* one question through the SDSC API and return the answer.

    The question is formatted with a LangChain ``PromptTemplate`` and sent as
    the user message of a chat-completions request via ``CustomSDSCLLM``.

    Args:
        question: The user question to ask.
        model: Model name on the SDSC service (e.g. "llama3-sdsc", "gemma2").
        api_key: Bearer token for the SDSC LLM service.

    Returns:
        The model's reply text, as returned by ``CustomSDSCLLM.call``.
    """
    llm = CustomSDSCLLM(api_key=api_key, model=model)

    # Plain text only. The request already carries structured system/user
    # messages, so model-specific control tokens (e.g. Llama 3's
    # <|start_header_id|>) do not belong inside the user content. They would
    # also be meaningless for non-Llama models such as gemma2. Presumably
    # the server applies each model's own chat template — confirm against the
    # service docs if exact formatting matters.
    prompt = PromptTemplate(
        template=(
            "You are a helpful assistant answering a user question.\n"
            "Here is the user question: {question}"
        ),
        input_variables=["question"],
    )

    formatted_prompt = prompt.format(question=question)
    return llm.call(formatted_prompt)

if __name__ == "__main__":
    import os

    # Prefer the SDSC_API_KEY environment variable so the key does not have
    # to be written into source; otherwise replace the placeholder below.
    api_key = os.environ.get("SDSC_API_KEY", "<api_key>")
    question = "Solve 1 + 1"

    llama3_output = test_langchain(question, "llama3-sdsc", api_key)
    gemma2_output = test_langchain(question, "gemma2", api_key)

    # CustomSDSCLLM.call already strips surrounding whitespace from replies.
    print(f"llama3-sdsc: {llama3_output}")
    print("-" * 20)
    print(f"gemma2: {gemma2_output}")