I am trying to parse a Stardew Valley crop CSV, embed it into a Chroma vector store with OpenAI embeddings, and have ChatGPT answer questions about the data. The problem is that the responses I get from ChatGPT are not accurate. I am still learning LangChain and Chroma, so I am unsure whether something is wrong with my code or it's something on my end. Here are some example questions, the responses I got, and whether they were correct:
Question | Response | Correct |
---|---|---|
Which sells for more, a pumpkin or a rare seed? | The rare seed sells for more than the pumpkin | Yes |
What crop sells for the most in the fall? | The crop that sells for the most in Fall is the pumpkin, with a selling price of 320 | No (the Rare Seed sells for 3000) |
What are all the crops in the fall? | The crops in fall are Corn, Pumpkin, and Cranberry. | No (there are more, e.g. Fairy Rose and Grape) |
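As a sanity check on the Correct column, the expected answers can be computed from the data directly, without any LLM. A minimal sketch using Python's standard `csv` module, assuming the same column names as in the loader code and using only the four sample rows quoted at the end of this post (so the crop list here is partial; running it on the full `stardew-crops.csv` would give the complete one):

```python
import csv
import io

# The four sample rows quoted in this post (the file has no header row here).
SAMPLE_ROWS = """Fall,Fairy Rose,200,290,12,0,1.0,7.50,180,Pierre,
Fall,Grape,60,80,10,3,1.0,2.00,420,Pierre,
Fall,Pumpkin,100,320,13,0,1.0,16.92,440,Pierre,
Fall,Rare Seed,1000,3000,24,0,1.0,83.33,2000,Gypsy,Receive Sweet Gem Berry
"""

# Same column names the CSVLoader is given.
FIELDNAMES = ["Season", "Crop", "Buy", "Sell", "Grow", "Regr", "Yield",
              "Profit/D", "Profit/M", "Vendor", "Notes"]

rows = list(csv.DictReader(io.StringIO(SAMPLE_ROWS), fieldnames=FIELDNAMES))

# Keep only the Fall rows, then list their crops and find the best seller.
fall = [row for row in rows if row["Season"] == "Fall"]
fall_crops = [row["Crop"] for row in fall]
best = max(fall, key=lambda row: int(row["Sell"]))

print("Fall crops:", ", ".join(fall_crops))
print("Highest sell price in Fall:", best["Crop"], best["Sell"])
```

Aggregate questions like these ("the most", "all of them") need every row at once, which is worth keeping in mind when comparing against what the chain sees.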
Here is my code:

```python
import os

import constants  # local module that holds OPENAI_API_KEY

from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

os.environ["OPENAI_API_KEY"] = constants.OPENAI_API_KEY

embedding_function = OpenAIEmbeddings()

# Each CSV row becomes one Document; fieldnames label the columns.
loader = CSVLoader(
    file_path="./stardew-crops.csv",
    csv_args={
        "delimiter": ",",
        "fieldnames": ["Season", "Crop", "Buy", "Sell", "Grow", "Regr", "Yield",
                       "Profit/D", "Profit/M", "Vendor", "Notes"],
    },
)
documents = loader.load()

# Embed every row and store it in an in-memory Chroma collection.
db = Chroma.from_documents(documents, embedding_function)
# With no search_kwargs, the retriever returns the 4 most similar rows.
retriever = db.as_retriever()

template = """You are a Stardew Valley crop specialist.
{context}
Question: {question}
"""
prompt = ChatPromptTemplate.from_template(template)
model = ChatOpenAI()


def format_docs(docs):
    # Join the retrieved rows as plain text instead of passing the raw Document list.
    return "\n\n".join(doc.page_content for doc in docs)


chain = (
    {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough()}
    | prompt
    | model
    | StrOutputParser()
)

# Continuous querying loop
while True:
    query = input("Please enter your query (or type 'exit' to quit): ")
    if query.lower() == "exit":
        print("Exiting...")
        break
    # Invoke the chain with the current query and print the response
    response = chain.invoke(query)
    print(response)
```
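A possible reason the "all crops in the fall" answer is incomplete: a vector-store retriever does top-k retrieval, so the model only sees the k rows ranked most similar to the question (4 by default for `as_retriever()` on Chroma), not the whole CSV. A toy sketch of top-k retrieval, assuming plain word overlap as the similarity score in place of OpenAI embeddings, and `k=2` so the cut-off is visible with only the four sample rows. The `tokens` and `top_k` helpers are hypothetical illustrations, not LangChain functions:

```python
import re

# The four sample rows quoted in this post, one "document" per row.
DOCS = [
    "Fall,Fairy Rose,200,290,12,0,1.0,7.50,180,Pierre,",
    "Fall,Grape,60,80,10,3,1.0,2.00,420,Pierre,",
    "Fall,Pumpkin,100,320,13,0,1.0,16.92,440,Pierre,",
    "Fall,Rare Seed,1000,3000,24,0,1.0,83.33,2000,Gypsy,Receive Sweet Gem Berry",
]


def tokens(text):
    # Lowercase alphabetic words; a crude stand-in for an embedding.
    return set(re.findall(r"[a-z]+", text.lower()))


def top_k(query, docs, k):
    # Score each document by how many words it shares with the query and keep
    # only the k best (sorted() is stable, so ties keep their input order).
    query_words = tokens(query)
    ranked = sorted(docs, key=lambda doc: len(query_words & tokens(doc)), reverse=True)
    return ranked[:k]


retrieved = top_k("What are all the crops in the fall?", DOCS, k=2)
for doc in retrieved:
    print(doc)
```

All four rows are Fall crops, but only two reach the "prompt". In the LangChain code the corresponding knob is the retriever's `k`, passed through `search_kwargs`, e.g. `db.as_retriever(search_kwargs={"k": 20})`; a larger `k` also makes the prompt longer.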
Here are a few rows from the CSV:
```csv
Fall,Fairy Rose,200,290,12,0,1.0,7.50,180,Pierre,
Fall,Grape,60,80,10,3,1.0,2.00,420,Pierre,
Fall,Pumpkin,100,320,13,0,1.0,16.92,440,Pierre,
Fall,Rare Seed,1000,3000,24,0,1.0,83.33,2000,Gypsy,Receive Sweet Gem Berry
```
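For reference, CSVLoader turns each of these rows into its own Document, and, as far as I know, its `page_content` is one `column: value` line per field. A standard-library sketch approximating that formatting for the Pumpkin row, assuming the same fieldnames as the loader (the `row_to_text` helper is hypothetical, not part of LangChain):

```python
import csv
import io

FIELDNAMES = ["Season", "Crop", "Buy", "Sell", "Grow", "Regr", "Yield",
              "Profit/D", "Profit/M", "Vendor", "Notes"]

ROW = "Fall,Pumpkin,100,320,13,0,1.0,16.92,440,Pierre,\n"


def row_to_text(row):
    # One "column: value" line per field, approximating CSVLoader's page_content.
    return "\n".join(f"{key}: {value.strip()}" for key, value in row.items())


row = next(csv.DictReader(io.StringIO(ROW), fieldnames=FIELDNAMES))
text = row_to_text(row)
print(text)
```

Because each row is embedded separately, a question such as "what sells for the most" is matched against single-row texts like this one; no individual document contains a comparison across all crops.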