Model Templates
Refer to the templates below to create your own model Python file (.py file).
Notes:
Ensure the model weight path is
/opt/models/mlflowmodel/modelweight
in the .py files. The section below is mandatory in every .py file:
import mlflow
import os

# Read the tracking server address from the environment and hand it to MLflow.
tracking_uri = os.getenv('MLFLOW_TRACKING_URL')
mlflow.set_tracking_uri(tracking_uri)
Template For Completion Model (example: BART)
#!/usr/bin/env python
# coding: utf-8
# Third-party dependencies used by this template.
import torch
from datasets import Dataset, load_dataset
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)

# Name of the reference checkpoint; it is not referenced again in this template.
MODEL_NAME = "philschmid/bart-large-cnn-samsum"

# Directory from which the model and tokenizer are loaded.
model_path = "/opt/models/mlflowmodel/modelweight"

# Load the BART model and its tokenizer from the same weight directory.
model = BartForConditionalGeneration.from_pretrained(model_path)
tokenizer = BartTokenizer.from_pretrained(model_path)
# Default system prompt. The literal holds only a newline, so after strip()
# it is the empty string; put your own instruction text between the quotes.
DEFAULT_SYSTEM_PROMPT = """
""".strip()


def generate_prompt_inference(
    statement: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT
) -> str:
    """Build the instruction-style prompt sent to the model at inference time.

    Args:
        statement: The input text; surrounding whitespace is stripped before
            it is placed under the ``### Input:`` heading.
        system_prompt: Instruction text placed after ``### Instruction:``.
            Defaults to ``DEFAULT_SYSTEM_PROMPT``.

    Returns:
        The assembled prompt with leading and trailing whitespace removed,
        ending with the ``### Response:`` heading.
    """
    prompt = (
        f"### Instruction: {system_prompt}\n"
        "### Input:\n"
        f"{statement.strip()}\n"
        "### Response:\n"
    )
    # The final strip drops the newline after "### Response:".
    return prompt.strip()
import mlflow
from mlflow.models import infer_signature

# Sample passage wrapped into a prompt; it serves as the logged input example.
sample_passage = (
    "India has 28 states and 8 union territories with different cultures "
    "and is the most populated country in the world.[17] The Indian "
    "culture, often labeled as an amalgamation of several various "
    "cultures, spans across the Indian subcontinent and has been "
    "influenced and shaped by a history that is several thousand years "
    "old. "
)
input_example = {"prompt": generate_prompt_inference(sample_passage)}

# Generation parameters recorded in the signature as inference-time params.
inference_config = dict(
    temperature=1.0,
    max_new_tokens=100,
    do_sample=True,
)

# Infer the model signature (input schema, output schema and params) from
# the examples above.
# NOTE(review): the sample output text mentions SQL although this template
# registers a summarization model -- presumably only its type (a string)
# matters for schema inference; confirm.
signature = infer_signature(
    model_input=input_example,
    model_output="SQL generated is ...",
    params=inference_config,
)
import mlflow
import os

# Point the MLflow client at the tracking server named in the environment.
tracking_uri = os.getenv('MLFLOW_TRACKING_URL')
mlflow.set_tracking_uri(tracking_uri)

# Register the model in MLflow, providing artifact information such as the
# task type, pip requirements, input example and any additional metadata.
with mlflow.start_run(run_name="bart-large-cnn-samsum_completions") as run:
    mlflow.transformers.log_model(
        transformers_model={
            "model": model,
            "tokenizer": tokenizer,
        },
        task="summarization",
        artifact_path="model",
        pip_requirements=["torch", "transformers"],
        input_example=input_example,
        signature=signature,
        # Add the metadata task so that the model serving endpoint created
        # later will be optimized.
        metadata={"task": "llm/v1/completions"},
        registered_model_name="bart-large-cnn-samsum_completions",
    )
    # Resolve the URI while the run is still active so it refers to the run
    # that logged the model above.
    model_uri = mlflow.get_artifact_uri("model")

# The with-statement has already closed the run; this call is kept from the
# original template as a safeguard.
mlflow.end_run()
Template for EMBEDDING MODEL (example: BERT (MPNET))
#!/usr/bin/env python
# coding: utf-8
import os

import mlflow
from sentence_transformers import SentenceTransformer

# Point the MLflow client at the tracking server named in the environment.
tracking_uri = os.getenv('MLFLOW_TRACKING_URL')
mlflow.set_tracking_uri(tracking_uri)

# Name, tags and description used when the model is registered.
Registered_model = "all-mpnet-base-v2_embedding"
tags = {"gathr": "foundation_model"}
description = (
    "This is a sentence-transformers model: It maps sentences & paragraphs "
    "to a 768 dimensional dense vector space and can be used for tasks "
    "like clustering or semantic search."
)

# Load the sentence-transformers model from the weight directory.
modelPath = "/opt/models/mlflowmodel/modelweight"
model = SentenceTransformer(modelPath)

# Example input; it is also encoded below to provide an example output.
sentences = ["This is an example sentence", "Each sentence is converted"]

# Infer the model signature (input and output schema) from the example
# sentences and the embeddings the model produces for them.
signature = mlflow.models.infer_signature(
    model_input=sentences,
    model_output=model.encode(sentences),
)
# Register the model in MLflow, providing artifact information such as the
# pip requirements with versions, input example and any additional metadata.
with mlflow.start_run(
    run_name="all-mpnet-base-v2_embedding",
    tags=tags,
    description=description,
) as run:
    mlflow.sentence_transformers.log_model(
        model=model,
        artifact_path="all_mpnet_base_v2_embedding",
        pip_requirements=[
            "torch==2.2.1",
            "transformers==4.37.1",
            "accelerate==0.26.1",
            "sentence-transformers== 2.5.1",
        ],
        signature=signature,
        input_example=sentences,
        # NOTE(review): MLflow's own examples spell this task
        # "llm/v1/embeddings" (plural); confirm which value the serving
        # platform expects before changing it.
        metadata={"task": "llm/v1/embedding"},
        registered_model_name="all-mpnet-base-v2_embedding",
    )
    # Resolve the URI while the run is still active. With no argument this is
    # the run's artifact root, not the model sub-directory.
    model_uri = mlflow.get_artifact_uri()

# The with-statement has already closed the run; this call is kept from the
# original template as a safeguard.
mlflow.end_run()
If you have any feedback on Gathr documentation, please email us!