LangChain Integration
Using the API with LangChain
Many applications use the LangChain library. The example below shows how to integrate the SDSC LLM service with LangChain: to communicate with the LLM over the API while still fitting into LangChain workflows, we wrap the API call in a small custom class.
Replace <api_key> with your actual API key in the code below.
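Before adding the LangChain wrapper, it can help to confirm that the endpoint responds to a plain request. The snippet below is only a minimal sanity-check sketch; it reuses the same URL, headers, and payload shape as the wrapper class that follows, and the model name llama3-sdsc is taken from the example at the end of this section.

import requests

# Minimal connectivity check against the chat/completions endpoint (sketch only).
# Assumes the llama3-sdsc model used in the example below; replace <api_key> with your key.
response = requests.post(
    "https://sdsc-llm-openwebui.nrp-nautilus.io/api/chat/completions",
    headers={"Authorization": "Bearer <api_key>", "Content-Type": "application/json"},
    json={
        "model": "llama3-sdsc",
        "messages": [{"role": "user", "content": "Solve 1 + 1"}],
        "stream": False,
    },
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])

Once this works, the full wrapper and a LangChain prompt template look like this: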
import requests
from langchain_core.prompts import PromptTemplate


class CustomSDSCLLM:
    """Thin wrapper around the SDSC LLM chat-completions API."""

    def __init__(self, api_key: str, model: str, base_url: str = "https://sdsc-llm-openwebui.nrp-nautilus.io/api/chat/completions"):
        self.api_key = api_key
        self.model = model
        self.base_url = base_url

    def call(self, prompt: str) -> str:
        """Send a single-turn chat request and return the assistant's reply."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a helpful assistant."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "stream": False
        }
        response = requests.post(self.base_url, headers=headers, json=payload)
        if response.status_code == 200:
            response_data = response.json()
            return response_data['choices'][0]['message']['content'].strip()
        else:
            raise Exception(f"Request failed with status {response.status_code}: {response.text}")
def test_langchain(question: str, model: str, api_key: str):
    llm = CustomSDSCLLM(api_key=api_key, model=model)
    # Prompt template written with Llama 3 chat-formatting tokens
    prompt = PromptTemplate(
        template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a helpful assistant answering
        a user question.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """,
        input_variables=["question"]
    )
    # Create the formatted prompt
    formatted_prompt = prompt.format(question=question)
    # Call the LLM
    return llm.call(formatted_prompt)
if __name__ == "__main__":
    api_key = "<api_key>"  # Replace with your actual API key
    llama3_output = test_langchain("Solve 1 + 1", "llama3-sdsc", api_key)
    gemma2_output = test_langchain("Solve 1 + 1", "gemma2", api_key)
    print("llama3-sdsc: " + llama3_output.strip())
    print("-" * 20)
    print("gemma2: " + gemma2_output.strip())