個人メモ
Hugging Faceで公開されているLlama2のモデルを使用して、4bit量子化を有効にして、70Bのモデルを1GPU(A100)で推論する方法について記述する。
dockerコンテナ作成
NVIDIAのPyTorchイメージを使用してDockerコンテナを作成する。
※ホストのドライババージョンが古いため、少し前のイメージを使用している。
コマンド例
docker run --gpus all --network host -v /work:/work -w /work -it nvcr.io/nvidia/pytorch:22.12-py3
PyTorchバージョンアップ
xformersがpytorch 2.0.1を要求するためPyTorchをアンインストールしてからインストール
pip uninstall torch torchvision torchtext torch-tensorrt pip install torch torchvision torchtext torchaudio --index-url https://download.pytorch.org/whl/cu118
transformer-engineアンインストール
参考:https://github.com/microsoft/TaskMatrix/issues/116
pip uninstall transformer-engine
ライブラリインストール
pip install transformers sentencepiece accelerate xformers bitsandbytes
スクリプト作成
モデルを読み込む際に、「load_in_4bit=True」を指定することで4bit量子化が有効になる。
import argparse parser = argparse.ArgumentParser() parser.add_argument("model") parser.add_argument("prompt") args = parser.parse_args() import transformers from transformers import AutoModelForCausalLM from transformers import AutoTokenizer model = AutoModelForCausalLM.from_pretrained( args.model, load_in_4bit=True, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(args.model) pipeline = transformers.pipeline( "text-generation", model=model, tokenizer=tokenizer, ) prompt = open(args.prompt, "r").read() # 推論の実行 sequences = pipeline( prompt, do_sample=True, top_k=10, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id, max_length=200, ) print(sequences[0]["generated_text"])
プロンプト作成
公式Blogのプロンプト記述方法を参考に、以下の通りプロンプトを作成する。
<s>[INST] <<SYS>> あなたは有能なアシスタントです。 日本語で回答してください。 <</SYS>> 日本で一番高い山は? [/INST]
prompt_ja.txtとして保存する。
スクリプト実行
コマンド例:
python question.py meta-llama/Llama-2-70b-chat-hf prompt_ja.txt
応答例:
The highest mountain in Japan is 富士山 (Fuji-san) with an elevation of 3,776 meters (12,421 feet) above sea level.
日本語で回答してくれないことが多い。
A100だと、1GPUで実行できる。
V100だと、2GPU必要だった。
4bit量子化しない場合
4bit量子化しない場合は、「load_in_4bit=True」を削除して、「torch_dtype=torch.float16」を指定する。
4bit量子化しない場合、A100だと2GPU、V100だと5GPU必要だった。
まとめ
Llama2の70Bモデルを4bit量子化して1GPU(A100)で実行する方法について記述した。
Llama2は、そのままだと日本語では回答できないことが多いため、日本語で使うにはファインチューニングが必要そうである。
日本語のファインチューニングについても別途試したい。