At roughly 350 input tokens it starts behaving strangely. Below that threshold the 3 GPUs each sit at about 33% utilization on average while predicting. Anywhere above it, one GPU sits at 99% and the rest at 0-3% (watching usage with nvidia-smi), and the process eventually OOMs. There is way too much code to post a working version, so here is the pseudo-ish code below, showing where any changes would be made.
import os

import torch
import transformers
from langchain_community.llms import HuggingFacePipeline  # or langchain.llms, depending on the langchain version

# Set max split size to avoid fragmentation, necessary to run sqlcoder34b on 2 GPUs
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
def load_defog_model(
    checkpoint, do_sample=False, repetition_penalty=1.1, load_in_8bit=False
):
    """Load defog model, tokenizer, and eos token id

    Args:
        checkpoint (str): huggingface model id, e.g. "defog/sqlcoder34B", or else path to a locally saved model
        do_sample (bool, optional): whether to sample tokens during generation. Defaults to False.
        repetition_penalty (float, optional): how much to penalize the model for repetition. Defaults to 1.1.
        load_in_8bit (bool, optional): whether to load the model in 8bit. If True, reduces memory consumption but slows inference. Defaults to False.

    Returns:
        transformers model: the defog sql generator model,
        transformers tokenizer: the model's corresponding tokenizer,
        int: eos token id, id of a special token we will use to represent the end of a sentence
    """
    tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint)
    # Use the ``` token as the end-of-generation marker
    eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0]
    model = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint,
        device_map="auto",  # let accelerate shard the model across the available GPUs
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        load_in_8bit=load_in_8bit,
        torch_dtype=torch.bfloat16,
        num_beams=5,
    )
    return model, tokenizer, eos_token_id
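
# --- Hedged diagnostic sketch (not from the original post) ---
# With device_map="auto", accelerate records the layer-to-device placement on the
# model as model.hf_device_map; printing it shows whether the layers were really
# spread across the GPUs or mostly stacked onto a single device.
def print_device_map(model):
    for module_name, device in model.hf_device_map.items():
        print(f"{module_name} -> {device}")
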
def load_defog_chain(
    prompt, checkpoint, do_sample=False, repetition_penalty=1.1, load_in_8bit=False
):
    """Create a pipeline to send requests to the local defog model and get back responses

    Args:
        prompt (PromptTemplate): natural language to sql prompt
        checkpoint (str): huggingface model id, e.g. "defog/sqlcoder34B", or else path to a locally saved model
        do_sample (bool, optional): whether to sample tokens during generation. Defaults to False.
        repetition_penalty (float, optional): how much to penalize the model for repetition. Defaults to 1.1.
        load_in_8bit (bool, optional): whether to load the model in 8bit. If True, reduces memory consumption but slows inference. Defaults to False.

    Returns:
        Langchain pipeline: run inference with the invoke method
    """
    model, tokenizer, eos_token_id = load_defog_model(
        checkpoint,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        load_in_8bit=load_in_8bit,
    )
    pipe = transformers.pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=300,
        num_return_sequences=1,
        eos_token_id=eos_token_id,
        pad_token_id=eos_token_id,
    )
    hf_pipe = HuggingFacePipeline(pipeline=pipe)
    chain = prompt | hf_pipe
    return chain
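
# --- Hedged usage sketch (not from the original post) ---
# The sql_chain used below is assumed to come from load_defog_chain; the prompt
# template text and the name sql_prompt are placeholders, and the checkpoint id
# is just the example from the docstring above.
from langchain.prompts import PromptTemplate

sql_prompt = PromptTemplate(
    input_variables=["question", "table_schema", "prior_conv"],
    template="...",  # the real natural-language-to-SQL template goes here
)
sql_chain = load_defog_chain(sql_prompt, "defog/sqlcoder34B")
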
sql_query = sql_chain.invoke(
    {"question": question, "table_schema": table_schema, "prior_conv": total_prompt}
)