TadaoYamaokaの開発日記

個人開発しているスマホアプリや将棋AIの開発ネタを中心に書いていきます。

拡散モデルで将棋の方策を学習する その2

前回の続き。
tadaoyamaoka.hatenablog.com

前回訓練損失が下がるところまで確認したが、正解率は0%のままだった。
その後学習を継続し、126,504ステップ時点で、訓練損失と正解率は以下の通りとなった。

学習結果

訓練損失


正解率

floodgateの棋譜に対する、正解率も約25%まで上昇している。
バッチサイズは32のため、学習した局面は4,048,128局面である。

拡散モデルでも、将棋の方策を学習できることが確認できた。

ResNetとの比較

ResNet30ブロック384フィルタのモデルを同じ局面数だけ学習した場合、正解率は約37%になる。
拡散モデルの方が学習に時間がかかっている。

2023/07/16 17:27:24     INFO    epoch = 1, steps = 4000, train loss = 2.1729847, 0.5884042, 0.5848547, 2.7602069, test loss = 2.0987635, 0.5968569, 0.6581656, 2.7160361, test accuracy = 0.3750000, 0.6640625

訓練時間

10kステップの学習に約1時間かかっている。

ResNet30ブロック384フィルタのモデルでは、バッチサイズ1024で同じ局面数した場合にかかる時間は、2分25秒である。

拡散モデルの学習は25倍近く時間がかかる。

まとめ

拡散モデルでも将棋の方策を学習できることが確認できた。
ただし、ResNetのモデルに比べて学習に非常に時間がかかる(25倍以上)ことがわかった。

大量に棋譜を学習させてどれくらいの精度になるか確認したいと思っているが、学習に時間がかかりすぎるため、PyTorch Lightningを使ってマルチGPUで学習できるようにしようと思う。

続く。

拡散モデルで将棋の方策を学習する

拡散モデルで、将棋の方策を学習できないか試してみた。

拡散モデル

拡散モデルは、高品質で多様なサンプルを生成できる生成モデルである。
昨年8月にStable Diffusionが公開されたことで注目を集めている。

拡散モデルは、確率微分方程式によって表される確率分布を近似するモデルで、モード崩壊を起こさず多様な分布を学習できるという特徴がある。

また、プロンプトと呼ばれるテキストにより条件付けを行い、テキストに従った画像を生成できる。

将棋の方策

将棋の方策は、座標と移動方向の組み合わせで表現でき、dlshogiで採用している表現方法では2187次元になる。
つまり、指し手は、局面によって条件づけられた2187次元の確率分布からサンプリングを行っていることになる。

拡散モデルの可能性

条件付けを行い高次元の確率分布からサンプリングを行うという仕組みは、将棋の方策においても適用できると考える。
将棋は、詰みのある局面や入玉した局面では指し手の性質が変わるという特徴があるが、拡散モデルの表現力であれば、異なるモードの状態でもうまく学習できるはずである。
現状のdlshogiでは、詰みのある局面は学習が難しいため、訓練データからは除外し、対局時に詰み探索に頼っている。

モデル構成

方策を複数チャンネルの9x9の画像として表現して、画像生成の拡散モデルと同じように学習する。

拡散過程では、方策からノイズに変換する。
逆拡散過程では、ノイズから方策に変換する。

条件付けは、局面をCNNで特徴マップに変換したものを使用する。
この部分はdlshogiの学習済みモデルのバックボーンの層を流用する。

実装

diffusersを使って、サンプルコードを参考にしながら実装した。

import argparse
import numpy as np
import types
from tqdm import tqdm

from diffusers import (
    DDPMScheduler,
    UNet2DConditionModel,
)
from diffusers.utils import randn_tensor

import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

from dlshogi.common import *
from dlshogi.data_loader import Hcpe3DataLoader
from dlshogi.data_loader import DataLoader
from dlshogi import serializers
from dlshogi.network.policy_value_network import policy_value_network


parser = argparse.ArgumentParser()
parser.add_argument("cache")
parser.add_argument("test_data")
parser.add_argument("model")
parser.add_argument("--network", default="resnet30x384_relu")
parser.add_argument("-e", "--epoch", type=int, default=1)
parser.add_argument("-g", "--gpu", type=int, default=0)
parser.add_argument("-b", "--batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument('--eval_interval', type=int, default=1000)
parser.add_argument("--num_inference_steps", type=int, default=20)
args = parser.parse_args()


if args.gpu >= 0:
    device = torch.device(f"cuda:{args.gpu}")
else:
    device = torch.device("cpu")


train_len, actual_len = Hcpe3DataLoader.load_files([], cache=args.cache)
train_data = np.arange(train_len, dtype=np.uint32)
test_data = np.fromfile(args.test_data, dtype=HuffmanCodedPosAndEval)

train_dataloader = Hcpe3DataLoader(train_data, args.batch_size, device, shuffle=True)
test_dataloader = DataLoader(test_data, args.eval_batch_size, device)


model = policy_value_network(args.network)


def forward(self, x1, x2):
    u1_1_1 = self.l1_1_1(x1)
    u1_1_2 = self.l1_1_2(x1)
    u1_2 = self.l1_2(x2)
    u1 = self.act(self.norm1(u1_1_1 + u1_1_2 + u1_2))
    return self.blocks(u1)


model.forward = types.MethodType(forward, model)
model.to(device)
serializers.load_npz(args.model, model)
model.eval()


unet = UNet2DConditionModel(
    sample_size=9,
    in_channels=27,
    out_channels=27,
    encoder_hid_dim=model.policy.in_channels * 9 * 9,
)
unet.to(device)

optimizer = AdamW(
    unet.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-08,
)

lr_scheduler = StepLR(optimizer=optimizer, step_size=1, gamma=0.5)

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
)

generator = torch.Generator(device=unet.device).manual_seed(0)

nl_loss = torch.nn.NLLLoss()
def accuracy(y, t):
    return (torch.max(y, 1)[1] == t).sum().detach().item() / len(t)

def eval():
    noise_scheduler = DDPMScheduler()

    x1, x2, policies, win, value = test_dataloader.sample()
    bsz = policies.shape[0]

    encoder_hidden_states = model(x1, x2).reshape(bsz, 1, -1)

    # inference (sample random noise and denoise)
    image = randn_tensor((bsz, 27, 9, 9), generator=generator, device=device)

    # set step values
    noise_scheduler.set_timesteps(args.num_inference_steps)

    unet.eval()
    for t in noise_scheduler.timesteps:
        # 1. predict noise model_output
        model_output = unet(image, t, encoder_hidden_states).sample

        # 2. compute previous image: x_t -> x_t-1
        image = noise_scheduler.step(model_output, t, image, generator=generator).prev_sample

    pred = image.reshape(bsz, -1)
    loss = nl_loss(torch.log(pred), policies)
    writer.add_scalar("loss/eval", loss.detach().item(), step)
    writer.add_scalar("accuracy/eval", accuracy(pred, policies), step)

    unet.train()

writer = SummaryWriter()


step = 0
for epoch in range(args.epoch):
    unet.train()
    for x1, x2, policies, win, value in tqdm(
        train_dataloader, total=train_len // args.batch_size, desc=f"epoch: {epoch}"
    ):
        bsz = policies.shape[0]
        policies = policies.reshape((bsz, 27, 9, 9))

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(policies)

        # Sample a random timestep for each image
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (bsz,),
            device=policies.device,
        )
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_policies = noise_scheduler.add_noise(policies, noise, timesteps)

        # Get the embedding for position
        encoder_hidden_states = model(x1, x2).reshape(bsz, 1, -1)

        target = noise

        # Predict the noise residual and compute loss
        model_pred = unet(noisy_policies, timesteps, encoder_hidden_states).sample

        loss = F.mse_loss(model_pred, target, reduction="mean")

        # Backpropagate
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        step += 1

        writer.add_scalar("loss/train", loss.detach().item(), step)

        if step % args.eval_interval == 0:
            eval()

    lr_scheduler.step()

条件付けは、テキストではなくdlshogiのモデルのバックボーンを使用している。
duffusersは、埋め込みのテンソルがシーケンスになっている必要があるため、シーケンスのサイズ=1とした。

Stable Diffusionのように、VAEで潜在空間への圧縮は行わず、直接方策の確率を入力している。

学習結果

まだ学習を少し試しただけだが、訓練損失は下がることを確認した。

逆拡散過程で生成した方策の正解率は0で、今のところ学習できていない。

確認したところ、出力に負の値が含まれていたため、方策の確率を直接学習するよりlogits空間に変換した方がよいかもしれない。

まとめ

拡散モデルによる将棋の方策の学習を試してみた。
推論速度が遅いため、大会向けの将棋AIとしては速度面で使い物にならなそうであるが、精度で既存モデルを上回るか検証したい。
今のところの感触は学習が難しそうである。

続く。
tadaoyamaoka.hatenablog.com

ChatGPT APIのFunction callingを試す

Chat Completions APIに新しく追加された機能である「Function calling」を試してみた。
Function calling and other API updates

Function calling

Bingチャットを使っていると、質問に応じて必要なときに検索を行う処理を行っている。同じように状況に応じてアクションをとる処理を実装しようとすると、プロンプトエンジニアリングを頑張る必要がある。
今回追加された「Function calling」は、このような用途をプロンプトエンジニアリングなしで実現することができる。

APIの呼び出し時に渡すfunctionsパラメータに、どのような状況で関数を実行するかの説明を書いておけば、質問文がその状況に該当する場合に、finish_reasonがfunction_callになって、関数引数の値とともに返してくれる。

その値をもとにアクションを行って、さらに生成を続けることができる。

要約に使えそう

他のユースケースとしては、例えば文書に対する質問を行うチャットボットの場合、プロンプトに「要約してください」が与えられた場合に、質問に対する回答ではなく、要約の関数を定義しておけば、文書全体の要約処理を行うことができる。
要約の時だけ、トークン数が4倍になった16kのモデルを使うとかもできそうである。

他のユースケース

他にも、OpenAIの記事ではデータベース クエリに変換する処理や、テキストから構造化データを抽出する用途が紹介されている。

OpenAI Cookbook

OpenAI CookbookにFunction callingのノートブックが追加されたので試してみた。
openai-cookbook/examples/How_to_call_functions_with_chat_models.ipynb at main · openai/openai-cookbook · GitHub

天気を検索

AIPのパラメータに渡す関数定義を以下のように行う。
関数定義のdescriptionに記述している「Get the current weather」に該当する質問の場合にこの関数が返されるようになる。

functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "format": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"],
                    "description": "The temperature unit to use. Infer this from the users location.",
                },
            },
            "required": ["location", "format"],
        },
    }
]


ユーザの質問が、「what is the weather like today」の場合、関数の引数locationが特定できないため、ChatGPTからは場所を聞く質問が返される。

conversation.add_message("user", "what is the weather like today")
{'role': 'assistant', 'content': 'Where are you currently located?'}

ユーザが場所を回答すると、関数を返す条件が整ったので、関数が返される。

conversation.add_message("user", "I'm in Glasgow, Scotland")
{'index': 0,
 'message': {'role': 'assistant',
  'content': None,
  'function_call': {'name': 'get_current_weather',
   'arguments': '{\n  "location": "Glasgow, Scotland",\n  "format": "celsius"\n}'}},
 'finish_reason': 'function_call'}

返された関数のパラメータを元に天気を検索する外部APIを呼び出すことができる。

SQLクエリを生成

functionsにSQLクエリを実行する関数を定義する。

functions = [
    {
        "name": "ask_database",
        "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            SQL query extracting info to answer the user's question.
                            SQL should be written using this database schema:
                            {database_schema_string}
                            The query should be returned in plain text, not in JSON.
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

埋め込み変数になっている{database_schema_string}の部分には、テーブルスキーマを挿入する。
ChatGPTがテーブルスキーマを知っておくことが重要ということである。


システムプロンプトを以下のように入力する。

agent_system_message = """You are ChinookGPT, a helpful assistant who gets answers to user questions from the Chinook Music Database.
Provide as many details as possible to your users
Begin!"""

(あなたは ChinookGPT です。Chinook 音楽データベースからユーザーの質問に回答する便利なアシスタントです。
できるだけ多くの詳細をユーザーに提供します)

ユーザメッセージを以下のように入力する。

sql_conversation.add_message(
    "user", "Hi, who are the top 5 artists by number of tracks"
)

(こんにちは、トラック数のトップ 5 アーティストは誰ですか)


以下のようなクエリの関数が返却される。

{'query': 'SELECT artists.Name, COUNT(tracks.TrackId) AS num_tracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY artists.Name ORDER BY num_tracks DESC LIMIT 5;'}

これを元に、データベースにクエリを実行することで、データベースに基づいた回答ができる。

ただし、直接生成されたクエリを実行するとSQLインジェクションなどのセキュリティ上の問題が起きるため、SQLのパラメータのみ生成して、SQL文は別で作りこむ方がよい。

複数の関数を定義して連携する

複数の関数を定義して、連携させることができる。
例では、arXivからキーワードに該当する論文一覧を取得する関数と、取得した論文から関連が高い論文を要約する関数を定義して、先に一覧取得していない場合は、要約を行わないようにしている。

# Initiate our get_articles and read_article_and_summarize functions
arxiv_functions = [
    {
        "name": "get_articles",
        "description": """Use this function to get academic papers from arXiv to answer user questions.""",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            User query in JSON. Responses should be summarized and should include the article URL reference
                            """,
                }
            },
            "required": ["query"],
        },
    },
    {
        "name": "read_article_and_summarize",
        "description": """Use this function to read whole papers and provide a summary for users.
        You should NEVER call this function before get_articles has been called in the conversation.""",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            Description of the article in plain text based on the user's query
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

(1つ目の関数の説明:ユーザーの質問に答えるために arXiv から学術論文を取得するには、この関数を使用します。
 2つ目の関数の説明:論文全体を読み、ユーザーに概要を提供するには、この機能を使用します。
         会話内で get_articles が呼び出される前に、この関数を決して呼び出さないでください。)
※ノートブックにはバグがあり、functionsが配列になっていなかったので修正している。


システムプロンプトに以下の内容を入力する。

paper_system_message = """You are arXivGPT, a helpful assistant pulls academic papers to answer user questions.
You summarize the papers clearly so the customer can decide which to read to answer their question.
You always provide the article_url and title so the user can understand the name of the paper and click through to access it.
Begin!"""
paper_conversa

(あなたは arXivGPT です。学術論文を引き出してユーザーの質問に答える役に立つアシスタントです。
顧客が質問に答えるためにどれを読めばよいかを判断できるように、論文を明確に要約します。
ユーザーが論文の名前を理解し、クリックして論文にアクセスできるように、必ずarticle_url と title を指定します。
始めよう!)

ユーザメッセージに以下の内容を入力する。

paper_conversation.add_message("user", "Hi, how does PPO reinforcement learning work?")

(こんにちは、PPO 強化学習はどのように機能しますか?)

get_articlesが返るので、記事の一覧を取得し、メッセージ履歴にroleをfunctionとしたメッセージを追加し、その際contentに取得した記事の内容を設定する。

messages.append(
    {
        "role": "function",
        "name": full_message["message"]["function_call"]["name"],
        "content": str(results),
    }

そして、会話を続けると以下のようにRead moreが論文のリンクになった論文一覧が生成される。

Here are some papers that explain how Proximal Policy Optimization (PPO) reinforcement learning works:

"Proximal Policy Optimization and its Dynamic Version for Sequence Generation" - This paper proposes using PPO for sequence generation tasks and introduces a dynamic approach for PPO. It demonstrates the efficacy of PPO in conditional sequence generation tasks. Read more

"CIM-PPO: Proximal Policy Optimization with Liu-Correntropy Induced Metric" - This article analyzes the asymmetry effect of KL divergence on PPO's objective function and introduces a new algorithm called CIM-PPO that incorporates the theory of correntropy. It compares the effectiveness of CIM-PPO with KL-PPO and Clip-PPO. Read more

"A2C is a special case of PPO" - This paper challenges the common understanding that A2C and PPO are separate algorithms and shows that A2C is actually a special case of PPO. It provides theoretical justifications and empirical experiments to support this claim. Read more

"Proximal Policy Optimization via Enhanced Exploration Efficiency" - This paper addresses the exploration ability of PPO and proposes an exploration enhancement mechanism based on uncertainty estimation. It introduces the proximal policy optimization algorithm with intrinsic exploration module (IEM-PPO) and evaluates its performance on multiple tasks. Read more

"Neural PPO-Clip Attains Global Optimality: A Hinge Loss Perspective" - This paper provides the first global convergence rate analysis of PPO-Clip under neural function approximation. It presents a reinterpretation of PPO-Clip using hinge loss and introduces a two-step policy improvement scheme for convergence analysis. Experimental results validate the approach on various RL benchmark tasks. Read more

Feel free to click on the links to access the papers for more details.


続けて、ユーザメッセージに同じ質問を入力する。

paper_conversation.add_message("user", "Hi, how does PPO reinforcement learning work?")

(こんにちは、PPO 強化学習はどのように機能しますか?)

今度は、read_article_and_summarizeが返るので、取得済みの記事一覧から返されたread_article_and_summarizeのパラメータのqueryに近いタイトルの記事を探して要約を実行する。
以下のように要約結果が得られる。

Core Argument:

The paper discusses the use of proximal policy optimization (PPO) in sequence generation tasks, specifically in the context of chit-chat chatbots.
The authors argue that PPO is a more efficient reinforcement learning algorithm compared to policy gradient, which is commonly used in these tasks.
They propose a dynamic approach for PPO (PPO-dynamic) and demonstrate its efficacy in synthetic experiments and chit-chat chatbot tasks.
Evidence:

PPO-dynamic achieves a high precision score in a synthetic counting task, comparable to other algorithms like REINFORCE and MIXER.
In the chit-chat chatbot task, PPO-dynamic achieves a slightly higher BLEU-2 score than REINFORCE and PPO.
The learning curves of PPO and PPO-dynamic are more stable than policy gradient, and PPO-dynamic converges faster.
Conclusions:

PPO is a better optimization method for sequence learning compared to policy gradient.
PPO-dynamic further improves the optimization process by dynamically adjusting the hyperparameters.
PPO can be used as a new optimization method for GAN-based sequence learning for better performance.

まとめ

Chat Completions APIに新しく追加された「Function calling」を試した。
Function callingにより、Bingチャットのようにユーザの質問文に応じて必要な場合に検索を呼び出すような処理をプロンプトエンジニアリングなしで実行できるようになる。

Bingチャットや、ChatGPT Plusのプラグインは、内部では前からこの機能を使って実装されていたのかもしれない。
APIでも同じような処理が実装できるようになったのでアプリケーションの応用範囲も広がりそうである。

dlshogiで量子化認識トレーニングを試す

先日、PyTorchでFXグラフモードで量子化認識トレーニングを試して、精度が大きく落ちることなく、CPUでの推論が速くなることを確認した。

そこで、dlshogiのモデルで量子化認識トレーニングを試して、CPUで強くなるかを検証した。

量子化認識トレーニン

先日の記事にも書いたが、量子化認識トレーニングは、訓練時に量子化(浮動小数点を整数にする)のスケールパラメータを学習することで、精度を保ったまま整数化により推論を高速する手法である。

PyTorchのFXグラフモードを使うことで、比較的簡単に量子化認識トレーニングを実装できる。

ONNX化の課題

dlshogiのモデルを量子化認識トレーニングすると、ONNXに出力する際、

Exporting the operator 'quantized::batch_norm2d_relu' to ONNX opset version 14 is not supported.

というエラーが発生し、ONNXに変換できなかった。

量子化した後のTorchScriptモデルで使われているquantized::batch_norm2d_reluオペレータが、ONNXに対応していないことが原因のようだ。

ONNX化はあきらめて、LibTorchで推論することにした。
LibTorchでも、TorchScriptモデルを使えば通常モデルに比べて高速に推論することができる。

強さ測定

CPUで実行することを考慮して、dlshogiの30ブロック384フィルタのモデルから5ブロック96フィルタおよび20ブロック256フィルタのモデルに蒸留を行った。

量子化前のモデルも含めて、開始局面から16手ランダムあり、1手2秒で連続対局を行った。
ハードウェアは、CPU 4コア(Xeon E3-1275 v6)、メモリ 64GBのマシンを使用した。
バッチサイズは2、探索スレッド数は1、ルートでの詰み探索あり、PVの詰み探索ありとした。

   # PLAYER             :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)    W    D    L  D(%)
   1 dlshogi_20bqat     :   280.7   49.2   472.0     800    59     100  465   14  321     2
   2 dlshogi_5bfp32     :   122.9   56.8   273.5     802    34     100  268   11  523     1
   3 dlshogi_5bqat      :    34.8   65.0   106.5     402    26      85  106    1  295     0
   4 dlshogi_20bfp32    :     0.0   ----    59.5     400    15     ---   57    5  338     1

White advantage = 49.19 +/- 9.52
Draw rate (equal opponents) = 1.45 % +/- 0.31

※同系列で偏らないように基準ソフトとして3コマ関係のソフトも加えているが省略している。
※末尾fp32が量子化前のモデル、qatが量子化したモデル

5ブロックのモデルは、量子化するとかえって弱くなった。
20ブロックのモデルは、量子化により強くなっている。
量子化前は、5ブロックの方が強い。

精度

各モデルのfloodgateの棋譜に対する正解率は以下の通り。

モデル 方策正解率 価値正解率
5ブロック(FP32) 0.4570605 0.7262039
5ブロック(量子化) 0.4400687 0.7097659
20ブロック(FP32) 0.5414595 0.7704120
20ブロック(量子化) 0.5317605 0.7590858

5ブロックのモデルでは、量子化により方策、価値ともに正解率がそれぞれ1.7%と1.6%低下している。
20ブロックのモデルでは、量子化により方策、価値ともに正解率がそれぞれ1.0%と1.1%低下している。
20ブロックの方が量子化による正解率の低下が小さい。

NPS

平手開始局面で、10秒探索させた際のNPSは以下の通り。

モデル NPS
5ブロック(FP32) 2049
5ブロック(量子化) 2453
20ブロック(FP32) 38
20ブロック(量子化) 165

5ブロックのモデルでは、量子化によりNPSが1.2倍になっている。
20ブロックのモデルでは、量子化によりNPSが4.34倍になっている。

20ブロックのモデルは、量子化による速度向上の効果が高い。
これは、モデルの推論時間が探索の多くを占めるためと思われる。

20ブロックのF32モデルは、NPSが38しかでておらず、少ない探索ではMCTSの結果が安定しないため、比較したモデルで一番弱かったのも納得できる。

floodgateに放流

比較したモデルで最もレーティングが高かった20ブロックの量子化モデルをfloodgateに放流して、レーティングを計測した。

floodgateのレーティングは約3600であった。

NNUE系は2コアでも約3800であるので、CPUではまだNNUE系を上回っていない。

まとめ

dlshogiで量子化認識トレーニングを試して、量子化前後の強さを比較した。
結果、20ブロックのモデルで量子化の効果が大きいことが確認できた。
5ブロックでは推論は速くなるものの精度低下により逆に弱くなった。

dlshogiを30ブロックの最新モデルを20ブロックに蒸留して量子化すると、4コアのCPUでNPSが165しかないがレーティングが約3600あることがわかった。
モデル精度を向上することで、CPUでもNNUE系を上回ることができるか引き続き検証したい。

2023/6/15追記

対戦数232でレーティングは3719となった。

React+TailwindCSSで高さ可変のテキストエリアを作る

ほぼ個人メモ。

HTMLのtextareaは、入力に応じて高さを自動で調整する機能がない。
ChatGPTの入力欄のように、高さが自動で変わるテキストエリアを作るには、CSSJavaScriptで実装する必要がある。

JavaScriptで高さを計算するのはややこしいので、非表示のdivタグに内容をコピーして、そのタグの高さにflexのrelativeとabsoluteを使って合わせる方法がある。
The Cleanest Trick for Autogrowing Textareas | CSS-Tricks - CSS-Tricks

このサンプルでは、Alpine.jsを使っていたので、Reactで実装する方法を調べた。

React+TailwindCSSで実装

Reactとlessを使って実装する以下のページを参考に、React+TailwindCSSで実装した。
入力文字の行数で縦幅が可変するtextareaのReactコンポーネント - Qiita

import React, { useRef } from 'react'

const App: React.FC = () => {
  const hiddenInput = useRef<HTMLDivElement>(null)

  return (
    <>
      <div className='relative w-96 border flex'>
        <div className='invisible min-h-[3em] overflow-x-hidden whitespace-pre-wrap break-words p-3' aria-hidden={true} ref={hiddenInput}></div>
        <textarea className='absolute top-0 w-full h-full resize-none p-3' onChange={ e => { if (hiddenInput.current) hiddenInput.current.textContent = e.target.value + '\u200b' }}></textarea>
      </div>
    </>
  )
}

export default App

divタグにコピーする際に、末尾に幅ゼロの文字('\u200b')がないと、末尾が空行の場合、高さがずれるため追加している。

【論文】Tree of Thoughts: Deliberate Problem Solving with Large Language Models

LLMで思考ゲームのような推論が必要なタスクを可能にする手法に関する論文「Tree of Thoughts: Deliberate Problem Solving with Large Language Models」を読んだ際のメモ。

LLMはアイディアを出したりする創作的なタスクや、要約や翻訳のような変換のタスクは得意だが、論理的な思考が必要なタスクは比較的苦手である。
Chain-of-thought (CoT)のようなプロンプトエンジニアリングによって、数学の問題もある程度解けるようになっているが、例えば将棋のような思考ゲームをプレイするようなことはできない。

この研究では、LLMとヒューリステックな探索を組み合わせることで、Game of 24という簡単な思考ゲームを解くことができる手法を提案している。

LLMの特性として、トークンを順番生成する過程で、途中で間違いに気づいても後戻りできずそのまま間違った回答をすることがある。
生成した終わった後に、内容が間違いかどうかは判断することはできる。
そこで、生成した内容をLLM自身で評価して、やり直すことができれば、正解を出せる可能性が高まる。

Game of 24

Game of 24は、4つの1桁数字を使用して、四則演算を組み合わせて答えが24になる数式を作るゲームである。

例:4 4 6 8
⇒(4 + 8) * (6 - 4) = 24

GPT4に解かせてみる

まずは、GPT4にfew-shotラーニングを使って、単一のプロンプトで解かせてみる。

Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: 5 5 4 3

回答:

Answer: (5 - 3) * (5 + 4) = 24

間違った答えが出力された。

CoTで段階的に思考させてみる

CoTで段階的に答えを出させてみた。

Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: 5 5 4 3

回答:

Steps:
5 * 5 = 25 (left: 3 4 25)
25 - 4 = 21 (left: 3 21)
21 + 3 = 24 (left: 24)
Answer: (5 * 5 - 4) + 3 = 24

正解している。
これは簡単な例で、頭から順に数字を使っていけばよい。

少し難しくしてみる。

Input: 2 6 4 3

回答:

Steps:
6 / 2 = 3 (left: 3 3 4)
3 * 3 = 9 (left: 4 9)
4 * 9 = 36 (left: 36)
Unfortunately, it is not possible to get 24 from 36 using only one number and basic arithmetic operations.

今度は解けなかった。

提案手法

このゲームは、3手の先読みをするゲーム木探索で回答することができる。
提案手法では、LLMでゲーム木の各ノードの枝の候補の生成と、展開されたノードの評価をLLMで行い、終端ノードで正しく回答ができたかをLLM自身で評価する。

BFSの手法で各ステップで5つまで枝を残してゲーム木探索を行うことで回答を得る。

温度0.7で複数回サンプリングを行い、平均的なパフォーマンスで評価した結果、既存手法を大幅に上回る結果が得られた。

枝の候補を生成するプロンプト

枝の候補は以下のようなプロンプトで生成する。

Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 /  2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Input: 5 5 4 3
Possible next steps:

出力:

5 + 5 = 10 (left: 10 4 3)
5 - 5 = 0 (left: 0 4 3)
5 * 5 = 25 (left: 25 4 3)
5 / 5 = 1 (left: 1 4 3)
4 + 5 = 9 (left: 9 5 3)
4 - 5 = -1 (left: -1 5 3)
4 * 5 = 20 (left: 20 5 3)
4 / 5 = 0.8 (left: 0.8 5 3)
3 + 5 = 8 (left: 8 5 4)
3 - 5 = -2 (left: -2 5 4)
3 * 5 = 15 (left: 15 5 4)
3 / 5 = 0.6 (left: 0.6 5 4)

ノードを評価するプロンプト

ノードは以下のようなプロンプトで3段階(sure/likely/impossible)で評価する。

Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
25 2 4

出力:

25 + 2 + 4 = 31
(25 - 2) * 4 = 23 * 4 = 92
2 * 4 = 8, but adding or subtracting 8 to/from 25 does not reach 24
However, 25 - 2 / 4 = 25 - 0.5 = 24.5 is within a reasonable range
likely
終端ノードを評価するプロンプト
Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge: 
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge: 
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge: 
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge: 
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge: 
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge: 
impossible
Input: 5 5 4 3
Answer: (5 - 3) * (5 + 4) = 24
Judge:

出力:

sure

試した例だと、間違った結果が出力された。

他の例だと上手くいった。

Input: 2 6 4 3
Answer: (6 - 4) * (2 + 3) = 24
Judge:

出力:

impossible
Input: 2 6 4 3
Answer: (6 * 4) * (3 - 2) = 24
Judge:

出力:

sure

この例「2 6 4 3」は、上記で単一のプロンプトでは解けなかったが、終端ノードが正しく評価できているので、ToTの手法では解けそうである。

論文では、他にも、「4つの文から文章を生成するタスク」と「ミニクロスワード」のタスクでも効果があったと報告している。

まとめ

LLMは論理的な思考が苦手であったが、ToTの手法によりヒューリステックな探索が必要なタスクでも解ける可能性が示された。
思考ゲームのようなタスクは、既存のアルゴリズムで超人的なパフォーマンスが達成できているので、わざわざLLMに解かせる必要はないが、より複雑な現実世界のタスク(コーディング、データ分析、ロボティクスなど)で、LLMを活用してパフォーマンスを向上できる可能性がある。

この手法は、枝の生成と評価に人間が設計したプロンプトが必要で、ゲーム木の探索もLLMの外側で実装が必要なため、将来的にはLLM自体で解決して欲しい課題である。

Whisperでリアルタイムに文字起こしするアプリ その2

以前に、Whisperでリアルタイムに文字起こしするアプリを公開したが、記事にアクセスが多く需要があるようなので、ONNXに変換する処理とアプリのソースコードを公開した。
ONNXへの変換処理は、メールで記事を読んでもわからんので、ソースを公開して欲しいという要望をいただいていた。

ONNXに変換する処理

github.com
to_onnx.pyを実行すると、encoder.onnxとdecoder.onnxが出力される。
ソース中にサンプル音声のファイルパスがハードコードされているので、修正が必要である。

リアルタイムに文字起こしするアプリ

github.com
ソースはC#で実装しているので、Unityに組み込んでリアルタイム音声認識することもできると思う。

音声の無音期間を実装が簡単なため音声の音圧が低い区間で検出しているが、no_speechトークンと、timestampトークンを使って処理する方がよい。
実装しようとしていたが、時間がとれずにペンディング中である。

まとめ

Whisperでリアルタイムに文字起こしするアプリの記事アクセスが多く需要がありそうなため、ソースを公開した。
音声の無音期間の検出がいい加減なので、どなたかフォークして実装してもらえないかと期待している。