IMDBot documentation

IMDBot is a chatbot that answers questions about movies using data loaded from a CSV file. It is powered by OctoAI and served through a Streamlit application.

Indexing and Query flows:

Setting up Environment Variables

The setup_env_variables function sets the OCTOAI_API_TOKEN and ENDPOINT_URL environment variables from the Streamlit secrets.

import os
import streamlit as st

def setup_env_variables():
    os.environ["OCTOAI_API_TOKEN"] = st.secrets['OCTOAI_API_TOKEN']
    os.environ["ENDPOINT_URL"] = st.secrets['ENDPOINT_URL']
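
The same copy step can be tried without Streamlit. The sketch below is only an illustration: export_secrets is a hypothetical helper, and the plain dictionaries stand in for st.secrets and os.environ. It also fails early if a secret is missing, which the original function does not do explicitly.

```python
import os

def export_secrets(secrets, names, environ=os.environ):
    """Copy the named secrets into an environment-like mapping."""
    for name in names:
        if name not in secrets:
            # Fail early with a clear message instead of at the first API call
            raise KeyError(f"missing secret: {name}")
        environ[name] = str(secrets[name])

# Hypothetical stand-ins for st.secrets and os.environ (values are made up)
fake_secrets = {
    "OCTOAI_API_TOKEN": "example-token",
    "ENDPOINT_URL": "https://example.invalid/predict",
}
fake_environ = {}
export_secrets(fake_secrets, ["OCTOAI_API_TOKEN", "ENDPOINT_URL"], environ=fake_environ)
```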

Initializing Session State

The handle_session_state function sets up the initial session state variables.

def handle_session_state():
    st.session_state.setdefault('generated', [])
    st.session_state.setdefault('past', [])
    st.session_state.setdefault('q_count', 0)
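
Because Streamlit reruns the script on every interaction, setdefault is used so that existing values survive a rerun. A minimal sketch of that behaviour, with a plain dict standing in for st.session_state (init_state is a hypothetical name):

```python
def init_state(state):
    """Create the keys only if they are not already present."""
    state.setdefault("generated", [])
    state.setdefault("past", [])
    state.setdefault("q_count", 0)
    return state

state = {}                      # stand-in for st.session_state
init_state(state)               # first run: keys are created
state["q_count"] += 1
state["past"].append("What year did Jaws come out?")
init_state(state)               # simulated rerun: nothing is reset
```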

Loading Movie Data

The load_data function loads movie data from a CSV file using a PagedCSVReader loader.

def load_data(file_path):
    PagedCSVReader = download_loader("PagedCSVReader")
    loader = PagedCSVReader()
    return loader.load_data(file_path)
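
To give a feel for what a row-per-document CSV loader produces, here is a rough standard-library sketch. It assumes the loader emits one text document per CSV row, formatted as "column: value" lines; the function name and the sample columns are hypothetical and are not taken from the real data file.

```python
import csv
import io

def rows_to_documents(csv_text):
    """Turn every CSV row into one 'column: value' text document."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return ["\n".join(f"{key}: {value}" for key, value in row.items())
            for row in reader]

# Made-up sample data; the real CSV has its own columns
sample_csv = "title,year\nJaws,1975\nAlien,1979\n"
documents = rows_to_documents(sample_csv)
```

Keeping each movie in its own document means a retrieved chunk refers to a single film rather than a mix of neighbouring rows.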

Initializing OctoAIEndpoint and LangChain LLMPredictor

The initialize_llm function initializes the OctoAIEndpoint and LLMPredictor.

def initialize_llm(endpoint_url):
    """Initialize the OctoAIEndpoint and LLMPredictor."""
    llm = OctoAIEndpoint(
        endpoint_url=endpoint_url,
        model_kwargs={
            "model": "llama-2-7b-chat-fp16",
            "messages": [
                {
                    "role": "system",
                    "content": "Below is an instruction that describes a task. Write a response that appropriately completes the request.",
                }
            ],
            "stream": False,
            "max_tokens": 256,
        },
    )
    return LLMPredictor(llm=llm)

Creating LangchainEmbedding

The create_embeddings function creates an instance of LangchainEmbedding using the OctoAIEmbeddings wrapper for a hosted Instructor-Large model endpoint.

def create_embeddings():
    if 'embeddings' not in st.session_state:
        embeddings = LangchainEmbedding(OctoAIEmbeddings(
            # The arguments for the hosted Instructor-Large endpoint are missing from the source
        ))
        st.session_state['embeddings'] = embeddings
    return st.session_state['embeddings']

Creating ServiceContext

The create_service_context function creates an instance of llama_index ServiceContext.

def create_service_context(llm_predictor, embeddings):
    if 'service_context' not in st.session_state:
        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, chunk_size_limit=400, embed_model=embeddings)
        st.session_state['service_context'] = service_context
    return st.session_state['service_context']
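
The embeddings, service context, index, and query engine functions all follow the same pattern: build the object once, store it in the session state, and return the stored copy afterwards. A generic sketch of that pattern, using a plain dict in place of st.session_state (get_or_create and build_expensive_object are hypothetical names):

```python
def get_or_create(state, key, factory):
    """Return state[key], calling factory() only the first time."""
    if key not in state:
        state[key] = factory()
    return state[key]

calls = []

def build_expensive_object():
    # Stand-in for an expensive constructor such as ServiceContext.from_defaults
    calls.append(1)
    return {"chunk_size_limit": 400}

state = {}
first = get_or_create(state, "service_context", build_expensive_object)
second = get_or_create(state, "service_context", build_expensive_object)
```

The factory runs only on the first call, so repeated script reruns reuse the same object.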

Creating GPTVectorStoreIndex

The create_index function creates an instance of llama_index GPTVectorStoreIndex. It first checks if the index.pkl file exists and loads it if available. If not, it creates a new GPTVectorStoreIndex from the given documents using the provided service_context.

def create_index(documents, service_context):
    if 'index' not in st.session_state:
        path = Path("index.pkl")
        if path.exists():
            # Reuse the previously serialized index
            with open(path, "rb") as f:
                index = dill.load(f)
        else:
            # Build a new index from the documents
            index = GPTVectorStoreIndex.from_documents(
                documents, service_context=service_context)
            #dill.dump(index, open(path, "wb")) #
        st.session_state['index'] = index
    return st.session_state['index']
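
The load-or-build logic can be exercised in isolation. The sketch below swaps in the standard-library pickle module for dill and a small dictionary for the index, so it illustrates only the caching flow. Note that the dump call is commented out in the function above; writing the file after a build, as done here, is an assumption about how index.pkl would be produced.

```python
import pickle
import tempfile
from pathlib import Path

def load_or_build(path, build, save=True):
    """Load a pickled object if the file exists, otherwise build (and save) it."""
    path = Path(path)
    if path.exists():
        with path.open("rb") as f:
            return pickle.load(f)
    obj = build()
    if save:
        with path.open("wb") as f:
            pickle.dump(obj, f)
    return obj

build_calls = []

def build_index():
    # Stand-in for GPTVectorStoreIndex.from_documents(...)
    build_calls.append(1)
    return {"Jaws": [0.1, 0.2]}

with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "index.pkl"
    built = load_or_build(cache, build_index)    # builds and writes the file
    loaded = load_or_build(cache, build_index)   # reads the file back
```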

Creating Query Engine

The create_query_engine function creates a llama_index query engine using the given index and llm_predictor.

def create_query_engine(index, llm_predictor):
    if 'query_engine' not in st.session_state:
        query_engine = index.as_query_engine(
            verbose=True, llm_predictor=llm_predictor)
        st.session_state['query_engine'] = query_engine
    return st.session_state['query_engine']

Processing Query

The query function processes a query and returns a response. It first gets a response from the llama_index query_engine and then transforms the response into a string.

def query(payload, query_engine):
    response = query_engine.query(payload["inputs"]["text"])
    # Transform response to string and remove leading newline character if present
    return str(response).lstrip("\n")
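
The payload shape and the string clean-up can be checked without a real index. In this sketch StubQueryEngine is a hypothetical stand-in that returns a string with leading newlines, similar to what a chat model may emit:

```python
class StubQueryEngine:
    """Hypothetical stand-in for a llama_index query engine."""

    def query(self, text):
        return f"\n\nYou asked: {text}"

def query(payload, query_engine):
    response = query_engine.query(payload["inputs"]["text"])
    # Convert the response to a string and drop leading newlines
    return str(response).lstrip("\n")

answer = query({"inputs": {"text": "Who directed Jaws?"}}, StubQueryEngine())
```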

Main Function

The main function initializes the environment, sets up the endpoint URL, loads the data, initializes the LLM predictor, and creates the embeddings, service context, index, and query engine. It then displays the user interface and processes the user’s input.

def main():
    # Setup the environment variables
    setup_env_variables()
    # Set the endpoint url
    endpoint_url = os.getenv("ENDPOINT_URL")
    # Initialize the session state
    handle_session_state()
    # Load the data
    documents = load_data(Path('rotten_tomatoes_top_movies.csv'))
    # Initialize the LLM predictor
    llm_predictor = initialize_llm(endpoint_url)
    # Create the embeddings
    embeddings = create_embeddings()
    # Create the service context
    service_context = create_service_context(llm_predictor, embeddings)
    # Create the index
    index = create_index(documents, service_context)
    # Create the query engine
    query_engine = create_query_engine(index, llm_predictor)
    # Display the header
    st.subheader("🎬  IMDBot - Powered by OctoAI")

The main function continues by setting up the user interface and getting the user’s input. It then processes the user’s input and displays the generated response on the user interface.

def main():
    # ... continued from before ...

    st.subheader("🎬  IMDBot - Powered by OctoAI")
    st.markdown('* :movie_camera: Tip #1: IMDBot is great at answering factual questions like: "Who starred in the Harry Potter movies?" or "What year did Jaws come out?"')
    st.markdown('* :black_nib: Tip #2: IMDBot loves the word "synopsis" -- we suggest using it if you are looking for plot summaries. Otherwise, expect some hallucinations.')
    st.markdown("* :blush: Tip #3: IMDBot has information about 500 popular movies, but is not comprehensive. It probably won't know some more obscure films.")
    st.markdown("### Welcome to the IMDBot demo")
    st.sidebar.image("octoml-octo-ai-logo-color.png", caption="Try OctoML's new compute service for free by signing up for early access:")

    try:
        # Get the user input
        user_input = get_text(q_count=st.session_state['q_count'])
        # If user input is not empty, process the input
        if user_input and user_input.strip() != '':
            output = query({"inputs": {"text": user_input}}, query_engine)
            # Increment q_count, append user input and generated output to session state
            st.session_state['q_count'] += 1
            if output:
                st.session_state['past'].append(user_input)
                st.session_state['generated'].append(output)
        # If there are generated messages, display them
        if st.session_state['generated']:
            for i in range(len(st.session_state['generated'])-1, -1, -1):
                message(st.session_state['past'][i], is_user=True, key=f'{i}_user')
                message(st.session_state["generated"][i], key=str(i))

    except Exception as e:
        st.error("Something went wrong. Please try again.")

if __name__ == "__main__":
    main()

If an exception is raised while the input is being processed or displayed, the app shows the error message “Something went wrong. Please try again.” The main function is called only when the script is run as the main module.
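
The display loop in main walks the history from the newest exchange to the oldest and gives every chat bubble its own key. A small sketch of that ordering, with plain lists in place of the session state (render_order is a hypothetical helper that returns the key and text pairs instead of drawing them):

```python
def render_order(past, generated):
    """Return (key, text) pairs in the order the loop would display them."""
    rows = []
    for i in range(len(generated) - 1, -1, -1):
        rows.append((f"{i}_user", past[i]))   # user message for exchange i
        rows.append((str(i), generated[i]))   # bot reply for exchange i
    return rows

rows = render_order(["q0", "q1"], ["a0", "a1"])
```

The most recent question and answer come first, and no two entries share a key.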