OpenAI Spyre Inference

Source: examples/online_inference/openai_spyre_inference.py.

""" 
This example shows how to use Spyre with vLLM for running online inference.

First, start the server with the following command:
    python3 -m vllm.entrypoints.openai.api_server \
        --model /models/llama-7b-chat/ \
        --max-model-len=2048 \
        --block-size=2048

By default, the server will use a batch size of 1, a max prompt length of 64 
tokens, and a max of 20 decode tokens.

You can change these with the env variables VLLM_SPYRE_WARMUP_BATCH_SIZES, 
VLLM_SPYRE_WARMUP_PROMPT_LENS, and VLLM_SPYRE_WARMUP_NEW_TOKENS.
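
For example, to warm up the server for a batch size of 4, 64-token prompts,
and 20 decode tokens (illustrative values; adjust to your workload), export
these before starting the server:
    export VLLM_SPYRE_WARMUP_BATCH_SIZES=4
    export VLLM_SPYRE_WARMUP_PROMPT_LENS=64
    export VLLM_SPYRE_WARMUP_NEW_TOKENS=20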
"""

import argparse
import time

from openai import OpenAI

parser = argparse.ArgumentParser(
    description="Script to submit an inference request to vllm server.")

parser.add_argument(
    "--max_tokens",
    type=int,
    default=20,
    help="Maximum number of tokens to generate per prompt. "
    "Must match VLLM_SPYRE_WARMUP_NEW_TOKENS.",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=1,
    help="Batch size per request. Must match VLLM_SPYRE_WARMUP_BATCH_SIZES.",
)
parser.add_argument(
    "--num_prompts",
    type=int,
    default=3,
    help="Total number of prompts to submit.",
)
parser.add_argument(
    "--stream",
    action=argparse.BooleanOptionalAction,
    help="Whether to stream the response.",
)

args = parser.parse_args()

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

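# Ask the server which models it is serving and use the first one.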
models = client.models.list()
model = models.data[0].id

template = (
    "Below is an instruction that describes a task. Write a response that "
    "appropriately completes the request. Be polite in your response to the "
    "user.\n\n### Instruction:\n{}\n\n### Response:")

instructions = [
    "Provide a list of instructions for preparing chicken soup for a family"
    " of four.",
    "Please compare New York City and Zurich and provide a list of"
    " attractions for each city.",
    "Provide detailed instructions for preparing asparagus soup for a"
    " family of four.",
]

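# Repeat the example prompts until there are at least num_prompts of them,
# then trim the list to exactly num_prompts.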
prompts = [template.format(instr) for instr in instructions]
prompts = prompts * (args.num_prompts // len(prompts) + 1)
prompts = prompts[0:args.num_prompts]

# This batch size must match VLLM_SPYRE_WARMUP_BATCH_SIZES
batch_size = args.batch_size
print(f"Submitting prompts with batch size {batch_size}")

# Submit the prompts in chunks of at most batch_size, so each request
# matches the batch size the server was warmed up with.
for i in range(0, len(prompts), batch_size):
    prompt = prompts[i:i + batch_size]

    stream = args.stream

    print(f"Prompt: {prompt}")
    start_t = time.time()

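    # Greedy decoding (temperature=0.0). max_tokens should match the
    # VLLM_SPYRE_WARMUP_NEW_TOKENS value the server was warmed up with.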
    completion = client.completions.create(model=model,
                                           prompt=prompt,
                                           echo=False,
                                           n=1,
                                           stream=stream,
                                           temperature=0.0,
                                           max_tokens=args.max_tokens)

    print("Results:")
    if stream:
        for c in completion:
            print(c)
    else:
        print(completion)

    # Take the end time only after the response (or stream) has been fully
    # consumed, so the reported duration covers the whole generation.
    end_t = time.time()
    total_t = end_t - start_t
    print(f"Duration: {total_t:.3f}s")