feat: add MetricsRecorder
McPatate committed Oct 31, 2024
1 parent a1d85f1 commit 796affb
Showing 5 changed files with 122 additions and 57 deletions.
1 change: 0 additions & 1 deletion .github/workflows/benchmark.yml
@@ -63,7 +63,6 @@ jobs:
commit_id=$GITHUB_SHA
fi
commit_msg=$(git show -s --format=%s | cut -c1-70)
df -h
python3 benchmark/benchmarks_entrypoint.py "${{ github.head_ref || github.ref_name }}" "$commit_id" "$commit_msg"
env:
HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }}
34 changes: 32 additions & 2 deletions benchmark/README.md
@@ -7,9 +7,39 @@ You will need to define a python function named `run_benchmark` in your python f
The expected function signature is the following:

```py
def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
def run_benchmark(logger: Logger, metrics_recorder: MetricsRecorder, num_tokens_to_generate=100):
```

## Writing metrics to the database

TODO
`MetricsRecorder` is thread-safe, in the sense of the Python [`Thread`](https://docs.python.org/3/library/threading.html#threading.Thread) API: you can start a background thread to collect device measurements without blocking the main thread that runs the model measurements.

See [`llama.py`](./llama.py) for an example of this in practice; a minimal sketch of the threading pattern follows the example below.

```py
def run_benchmark(logger: Logger, metrics_recorder: MetricsRecorder, num_tokens_to_generate=100):
benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
# To collect device measurements
metrics_recorder.collect_device_measurements(
benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
)
# To collect your model measurements
metrics_recorder.collect_model_measurements(
benchmark_id,
{
"model_load_time": model_load_time,
"first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
"second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
"first_eager_generate_time_secs": first_eager_generate_time,
"second_eager_generate_time_secs": second_eager_generate_time,
"time_to_first_token_secs": time_to_first_token,
"time_to_second_token_secs": time_to_second_token,
"time_to_third_token_secs": time_to_third_token,
"time_to_next_token_mean_secs": mean_time_to_next_token,
"first_compile_generate_time_secs": first_compile_generate_time,
"second_compile_generate_time_secs": second_compile_generate_time,
"third_compile_generate_time_secs": third_compile_generate_time,
"fourth_compile_generate_time_secs": fourth_compile_generate_time,
},
)
```
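
For reference, here is a minimal sketch of that background-collection pattern, closely following the `collect_metrics` helper in [`llama.py`](./llama.py); it assumes `benchmark_id` and `metrics_recorder` are already set up as in the example above:

```py
import os
from threading import Event, Thread
from time import sleep

import gpustat
import psutil


def collect_device_metrics(benchmark_id, stop_event, metrics_recorder):
    # Poll process and GPU stats until the main thread signals completion
    process = psutil.Process(os.getpid())
    while not stop_event.is_set():
        with process.oneshot():
            cpu_util = process.cpu_percent()
            mem_megabytes = process.memory_info().rss / (1024 * 1024)
        gpu_stats = gpustat.GPUStatCollection.new_query()
        # MetricsRecorder is thread-safe, so writing from this thread is fine
        metrics_recorder.collect_device_measurements(
            benchmark_id,
            cpu_util,
            mem_megabytes,
            gpu_stats[0]["utilization.gpu"],
            gpu_stats[0]["memory.used"],
        )
        sleep(0.01)


stop_event = Event()
metrics_thread = Thread(
    target=collect_device_metrics,
    args=[benchmark_id, stop_event, metrics_recorder],
)
metrics_thread.start()
# ... run the model measurements on the main thread ...
stop_event.set()  # tell the collector to stop
metrics_thread.join()
```

The 0.01 s polling interval matches `llama.py`; tune it to trade sampling resolution against overhead.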
57 changes: 57 additions & 0 deletions benchmark/benchmarks_entrypoint.py
@@ -2,13 +2,69 @@
import importlib.util
import logging
import os
from typing import Dict
import psycopg2
import sys

from psycopg2.extras import Json
from psycopg2.extensions import register_adapter


register_adapter(dict, Json)


class ImportModuleException(Exception):
pass


class MetricsRecorder:
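    """Records benchmark runs and their device/model measurements to the Postgres `metrics` database."""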
def __init__(self, connection, branch: str, commit_id: str, commit_msg: str):
self.conn = connection
self.branch = branch
self.commit_id = commit_id
self.commit_msg = commit_msg

def initialise_benchmark(self, metadata: Dict[str, str]) -> int:
"""
Creates a new benchmark, returns the benchmark id
"""
        # Typical metadata keys: gpu_name, model_id
with self.conn.cursor() as cur:
cur.execute(
"INSERT INTO benchmarks (branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
(self.branch, self.commit_id, self.commit_msg, metadata),
)
return cur.fetchone()[0]


def collect_device_measurements(self, benchmark_id: int, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes):
"""
Collect device metrics, such as CPU & GPU usage. These are "static" in the sense that you cannot pass arbitrary arguments to this function.
"""
with self.conn.cursor() as cur:
cur.execute(
"INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
(benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
)

def collect_model_measurements(self, benchmark_id: int, measurements: Dict[str, float]):
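        """
        Collect model metrics, such as load time and generation timings, as a single JSON measurement blob.
        """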
with self.conn.cursor() as cur:
cur.execute(
"""
INSERT INTO model_measurements (
benchmark_id,
measurements
) VALUES (%s, %s)
""",
(
benchmark_id,
measurements,
),
)

def close(self):
self.conn.close()

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -64,6 +120,7 @@ def import_from_path(module_name, file_path):

branch, commit_id, commit_msg = parse_arguments()

metrics_recorder = MetricsRecorder(psycopg2.connect("dbname=metrics"), branch, commit_id, commit_msg)
for entry in os.scandir(benchmarks_folder_path):
try:
if not entry.name.endswith(".py"):
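
The diff is truncated here. For orientation, below is a hypothetical sketch of how a dispatch loop like this can hand the shared recorder to each benchmark module, assuming the `import_from_path` helper named in the hunk header above and the `run_benchmark` signature documented in the README; the commit's actual loop body is not shown in this view:

```py
# Hypothetical sketch, not the commit's actual loop body (truncated above)
for entry in os.scandir(benchmarks_folder_path):
    try:
        if not entry.name.endswith(".py"):
            continue
        # Load each benchmark file as a module and pass the shared recorder in
        module = import_from_path(entry.name.split(".")[0], entry.path)
        module.run_benchmark(logger, metrics_recorder)
    except ImportModuleException as e:
        logger.error(e)
metrics_recorder.close()  # close the shared DB connection once all benchmarks ran
```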
2 changes: 1 addition & 1 deletion benchmark/init_db.sql
@@ -3,7 +3,7 @@ CREATE TABLE IF NOT EXISTS benchmarks (
branch VARCHAR(255),
commit_id VARCHAR(72),
commit_message VARCHAR(70),
gpu_name VARCHAR(255),
metadata jsonb,
created_at timestamp without time zone NOT NULL DEFAULT (current_timestamp AT TIME ZONE 'UTC')
);
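
With `gpu_name` folded into the new `metadata` jsonb column, per-benchmark attributes are read back with Postgres json operators. A minimal sketch, assuming the `metrics` database used elsewhere in this commit:

```py
import psycopg2

conn = psycopg2.connect("dbname=metrics")
with conn.cursor() as cur:
    # ->> extracts a jsonb field as text, e.g. the gpu_name recorded at benchmark start
    cur.execute("SELECT benchmark_id, metadata->>'gpu_name' FROM benchmarks")
    for benchmark_id, gpu_name in cur.fetchall():
        print(benchmark_id, gpu_name)
conn.close()
```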

85 changes: 32 additions & 53 deletions benchmark/llama.py
@@ -3,83 +3,71 @@
from threading import Event, Thread
from time import perf_counter, sleep
from typing import Optional
from benchmark.benchmarks_entrypoint import MetricsRecorder
import gpustat
import psutil
import psycopg2
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, StaticCache
from psycopg2.extras import Json
from psycopg2.extensions import register_adapter


os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

os.environ["TOKENIZERS_PARALLELISM"] = "1"
torch.set_float32_matmul_precision("high")
register_adapter(dict, Json)


def collect_metrics(benchmark_id, continue_metric_collection):
def collect_metrics(benchmark_id, continue_metric_collection, metrics_recorder):
p = psutil.Process(os.getpid())
conn = psycopg2.connect("dbname=metrics")
cur = conn.cursor()
while not continue_metric_collection.is_set():
with p.oneshot():
cpu_util = p.cpu_percent()
mem_megabytes = p.memory_info().rss / (1024 * 1024)
gpu_stats = gpustat.GPUStatCollection.new_query()
gpu_util = gpu_stats[0]["utilization.gpu"]
gpu_mem_megabytes = gpu_stats[0]["memory.used"]
cur.execute(
"INSERT INTO device_measurements (benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes) VALUES (%s, %s, %s, %s, %s)",
(benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes),
metrics_recorder.collect_device_measurements(
benchmark_id, cpu_util, mem_megabytes, gpu_util, gpu_mem_megabytes
)
sleep(0.01)
conn.commit()
conn.close()


def run_benchmark(logger: Logger, branch: str, commit_id: str, commit_msg: str, num_tokens_to_generate=100):
def run_benchmark(logger: Logger, metrics_recorder: MetricsRecorder, num_tokens_to_generate=100):
continue_metric_collection = Event()
metrics_thread = None
model_id = "meta-llama/Llama-2-7b-hf"
try:
gpu_stats = gpustat.GPUStatCollection.new_query()
gpu_name = gpu_stats[0]["name"]
conn = psycopg2.connect("dbname=metrics")
cur = conn.cursor()
cur.execute(
"INSERT INTO benchmarks (branch, commit_id, commit_message, gpu_name) VALUES (%s, %s, %s, %s) RETURNING benchmark_id",
(branch, commit_id, commit_msg, gpu_name),
benchmark_id = metrics_recorder.initialise_benchmark({"gpu_name": gpu_name, "model_id": model_id})
logger.info(f"running benchmark #{benchmark_id} on {gpu_name} for {model_id}")
metrics_thread = Thread(
target=collect_metrics,
args=[benchmark_id, continue_metric_collection, metrics_recorder],
)
conn.commit()
benchmark_id = cur.fetchone()[0]
logger.info(f"running benchmark #{benchmark_id} on {gpu_name}")
metrics_thread = Thread(target=collect_metrics, args=[benchmark_id, continue_metric_collection])
metrics_thread.start()
logger.info("started background thread to fetch device metrics")

os.environ["TOKENIZERS_PARALLELISM"] = "false" # silence warnings when compiling

device = "cuda"
ckpt = "meta-llama/Llama-2-7b-hf"

logger.info("downloading weights")
# This is to avoid counting download in model load time measurement
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
gen_config = GenerationConfig(do_sample=False, top_p=1, temperature=1)
logger.info("loading model")
start = perf_counter()
model = AutoModelForCausalLM.from_pretrained(
ckpt, torch_dtype=torch.float16, generation_config=gen_config
model_id, torch_dtype=torch.float16, generation_config=gen_config
).eval()
model.to(device)
torch.cuda.synchronize()
end = perf_counter()
model_load_time = end - start
logger.info(f"loaded model in: {model_load_time}s")

tokenizer = AutoTokenizer.from_pretrained(ckpt)
tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "Why dogs are so cute?"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
@@ -326,36 +314,27 @@ def decode_one_token(model, cur_token, cache_position, past_key_values):
logger.info(f"completed second compile generation in: {fourth_compile_generate_time}s")
logger.info(f"generated: {tokenizer.batch_decode(output.cpu().tolist())}")

cur.execute(
"""
INSERT INTO model_measurements (
benchmark_id,
measurements
) VALUES (%s, %s)
""",
(
benchmark_id,
{
"model_load_time": model_load_time,
"first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
"second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
"first_eager_generate_time_secs": first_eager_generate_time,
"second_eager_generate_time_secs": second_eager_generate_time,
"time_to_first_token_secs": time_to_first_token,
"time_to_second_token_secs": time_to_second_token,
"time_to_third_token_secs": time_to_third_token,
"time_to_next_token_mean_secs": mean_time_to_next_token,
"first_compile_generate_time_secs": first_compile_generate_time,
"second_compile_generate_time_secs": second_compile_generate_time,
"third_compile_generate_time_secs": third_compile_generate_time,
"fourth_compile_generate_time_secs": fourth_compile_generate_time,
},
),
metrics_recorder.collect_model_measurements(
benchmark_id,
{
"model_load_time": model_load_time,
"first_eager_forward_pass_time_secs": first_eager_fwd_pass_time,
"second_eager_forward_pass_time_secs": second_eager_fwd_pass_time,
"first_eager_generate_time_secs": first_eager_generate_time,
"second_eager_generate_time_secs": second_eager_generate_time,
"time_to_first_token_secs": time_to_first_token,
"time_to_second_token_secs": time_to_second_token,
"time_to_third_token_secs": time_to_third_token,
"time_to_next_token_mean_secs": mean_time_to_next_token,
"first_compile_generate_time_secs": first_compile_generate_time,
"second_compile_generate_time_secs": second_compile_generate_time,
"third_compile_generate_time_secs": third_compile_generate_time,
"fourth_compile_generate_time_secs": fourth_compile_generate_time,
},
)
conn.commit()
conn.close()
except Exception as e:
logger.error(f"Caught exception: {e}")
continue_metric_collection.set()
if metrics_thread is not None:
metrics_thread.join()
metrics_recorder.close()
