Model Templates

Refer to the templates below to create your own model Python file (.py file).

Notes:

  • Ensure that the model weight path in the .py file is set to /opt/models/mlflowmodel/modelweight.

  • The following section is mandatory in every .py file:

    import mlflow
    import os
    tracking_uri = os.getenv('MLFLOW_TRACKING_URL')
    mlflow.set_tracking_uri(tracking_uri)
    

Template for Completion Model (example: BART)

#!/usr/bin/env python
# coding: utf-8

import torch
from datasets import Dataset, load_dataset

MODEL_NAME = "philschmid/bart-large-cnn-samsum"

from transformers import (
BartForConditionalGeneration,
BartTokenizer,
BitsAndBytesConfig,
TrainingArguments,
)

# Local directory from which both the model and its tokenizer are loaded.
model_path = "/opt/models/mlflowmodel/modelweight"

# Load the tokenizer and the BART model from the same local directory.
tokenizer = BartTokenizer.from_pretrained(model_path)
model = BartForConditionalGeneration.from_pretrained(model_path)


# System prompt placed after "### Instruction:" in every generated prompt.
# It is empty by default; put instruction text between the triple quotes
# to change it.
DEFAULT_SYSTEM_PROMPT = """

""".strip()


def generate_prompt_inference(
    statement: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Build the inference prompt for a single input statement.

    Args:
        statement: Text to place in the "### Input:" section. Leading and
            trailing whitespace is removed.
        system_prompt: Text placed after "### Instruction:". Defaults to
            ``DEFAULT_SYSTEM_PROMPT``.

    Returns:
        The prompt with "### Instruction:", "### Input:" and
        "### Response:" sections, stripped of surrounding whitespace.
    """
    # Explicit "\n" escapes keep the prompt text independent of the
    # source indentation of this function.
    prompt = (
        f"### Instruction: {system_prompt}\n"
        "\n"
        "### Input:\n"
        f"{statement.strip()}\n"
        "\n"
        "### Response:\n"
    )
    return prompt.strip()


import mlflow
from mlflow.models import infer_signature

# Input example logged with the model: a single "prompt" field built with
# generate_prompt_inference().
input_example = {"prompt": generate_prompt_inference("India has 28 states and 8 union territories with different cultures and is the most populated country in the world.[17] The Indian culture, often labeled as an amalgamation of several various cultures, spans across the Indian subcontinent and has been influenced and shaped by a history that is several thousand years old. ")}

# Default inference parameters recorded in the signature as params.
inference_config = {
    "temperature": 1.0,
    "max_new_tokens": 100,
    "do_sample": True,
}

# Define the model signature: input schema, output schema and params.
# The sample output is a placeholder summary string for this summarization
# model (the previous placeholder text referred to SQL generation).
signature = infer_signature(
    model_input=input_example,
    model_output="Summary generated is ...",
    params=inference_config,
)


import mlflow
import os
# Point the MLflow client at the tracking server named by the
# MLFLOW_TRACKING_URL environment variable.
tracking_uri = os.getenv('MLFLOW_TRACKING_URL')
mlflow.set_tracking_uri(tracking_uri)

# Register the model in MLflow, providing artifact information such as the
# task type, pip requirements, input example and any additional metadata.
with mlflow.start_run(run_name="bart-large-cnn-samsum_completions") as run:
    mlflow.transformers.log_model(
        transformers_model={
            "model": model,
            "tokenizer": tokenizer,
        },
        task="summarization",
        artifact_path="model",
        pip_requirements=["torch", "transformers"],
        input_example=input_example,
        signature=signature,
        # Add the metadata task so that the model serving endpoint created later will be optimized
        metadata={"task": "llm/v1/completions"},
        registered_model_name="bart-large-cnn-samsum_completions",
    )
    # Resolve the URI of the logged "model" artifact while the run is
    # still active.
    model_uri = mlflow.get_artifact_uri("model")
# Kept from the original template; the ``with`` block above has already
# closed the run at this point.
mlflow.end_run()

Template for Embedding Model (example: BERT (MPNet))

#!/usr/bin/env python
# coding: utf-8

from sentence_transformers import SentenceTransformer
import mlflow

# register in MLFlow

import mlflow
import os
# Read the MLflow tracking server address from the environment and hand it
# to the MLflow client.
tracking_uri = os.environ.get('MLFLOW_TRACKING_URL')
mlflow.set_tracking_uri(tracking_uri)

# Registered-model name for this template.
Registered_model = "all-mpnet-base-v2_embedding"

# Tags and description passed to the MLflow run.
tags = {"gathr": "foundation_model"}
description = (
    "This is a sentence-transformers model: It maps sentences & paragraphs "
    "to a 768 dimensional dense vector space and can be used for tasks like "
    "clustering or semantic search."
)


# Load the sentence-transformers model from the local weight directory.

modelPath = "/opt/models/mlflowmodel/modelweight"
model = SentenceTransformer(modelPath)


# Prepare the input example and the model signature used for registration

# Sample sentences; they serve as the model input for signature inference.
sentences = ["This is an example sentence", "Each sentence is converted"]

# Encode the samples once, then let MLflow infer the input and output
# schema from the sample input and the encoded output.
sample_embeddings = model.encode(sentences)
signature = mlflow.models.infer_signature(
    model_input=sentences,
    model_output=sample_embeddings,
)

## Register the model in mlflow, provide artifact information such as task type, pip requirement with version, input example and any additional metadata

# Artifact path under which the model is logged inside the run; reused
# below so that model_uri points at the logged model.
artifact_path = "all_mpnet_base_v2_embedding"

# Registered_model holds "all-mpnet-base-v2_embedding", which is used both
# as the run name and as the registered model name.
with mlflow.start_run(run_name=Registered_model, tags=tags, description=description) as run:
    mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path=artifact_path,
        pip_requirements=["torch==2.2.1", "transformers==4.37.1", "accelerate==0.26.1", "sentence-transformers== 2.5.1"],
        signature=signature,
        input_example=sentences,
        # NOTE(review): MLflow/Databricks documentation spells the
        # embeddings task type "llm/v1/embeddings" (plural) - confirm which
        # value the serving platform expects before changing this string.
        metadata={"task": "llm/v1/embedding"},
        registered_model_name=Registered_model,
    )
    # URI of the logged model artifact (not just the run's artifact root),
    # resolved while the run is still active.
    model_uri = mlflow.get_artifact_uri(artifact_path)
# Kept from the original template; the ``with`` block above has already
# closed the run at this point.
mlflow.end_run()
Top