236 lines
5.3 KiB
Text
236 lines
5.3 KiB
Text
{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "id": "25d01f84-863b-4238-964b-e24425bb8107",
 "metadata": {},
 "outputs": [],
 "source": [
  "import gpflow\n",
  "import numpy as np\n",
  "import tensorflow as tf\n",
  "import trieste\n",
  "from trieste.acquisition.rule import OBJECTIVE\n",
  "from trieste.utils.objectives import branin"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 2,
 "id": "242c5eb2-2531-481d-8c21-0b0178ac2db5",
 "metadata": {},
 "outputs": [],
 "source": [
  "# One seed shared by NumPy and TensorFlow so that random sampling later in\n",
  "# the notebook (e.g. the initial design) is reproducible.\n",
  "SEED = 1793\n",
  "np.random.seed(SEED)\n",
  "tf.random.set_seed(SEED)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 3,
 "id": "49fa101c-15a9-47be-b254-6cbde6dd04e9",
 "metadata": {},
 "outputs": [],
 "source": [
  "# Two-dimensional search space: the unit square [0, 1]^2.\n",
  "search_space = trieste.space.Box([0, 0], [1, 1])"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 4,
 "id": "050686af-8c93-4996-b852-c925d7cc7e98",
 "metadata": {},
 "outputs": [],
 "source": [
  "def new_branin(data, verbose=False):\n",
  "    \"\"\"Evaluate ``branin`` on ``data``, optionally printing the batch shape.\n",
  "\n",
  "    Printing the shape shows how many points are queried per call, but doing it\n",
  "    unconditionally floods the output of every cell that calls the observer, so\n",
  "    it is opt-in.\n",
  "\n",
  "    Args:\n",
  "        data: Query points, forwarded unchanged to ``branin``.\n",
  "        verbose: If True, print ``data.shape`` before evaluating.\n",
  "\n",
  "    Returns:\n",
  "        The result of ``branin(data)``.\n",
  "    \"\"\"\n",
  "    if verbose:\n",
  "        print(data.shape)\n",
  "    return branin(data)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 5,
 "id": "cd44cddb-5090-4e4f-b651-f7d57f145b70",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "(5, 2)\n"
   ]
  }
 ],
 "source": [
  "# Wrap the objective as an observer tagged with OBJECTIVE, then evaluate it on\n",
  "# a random initial design drawn from the search space.\n",
  "observer = trieste.utils.objectives.mk_observer(new_branin, OBJECTIVE)\n",
  "\n",
  "num_initial_points = 5\n",
  "initial_query_points = search_space.sample(num_initial_points)\n",
  "initial_data = observer(initial_query_points)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 6,
 "id": "696dfd67-0e37-42f0-af5b-64937bffc2aa",
 "metadata": {},
 "outputs": [],
 "source": [
  "def build_model(data, lengthscales=(0.2, 0.2), noise_variance=1e-5, maxiter=100):\n",
  "    \"\"\"Build the trieste model spec: a GPR with a Matern-5/2 kernel.\n",
  "\n",
  "    Args:\n",
  "        data: Dataset exposing ``observations`` and ``astuple()``\n",
  "            (e.g. ``initial_data[OBJECTIVE]``).\n",
  "        lengthscales: Initial per-dimension kernel lengthscales.\n",
  "        noise_variance: Likelihood noise variance; kept fixed (not trained).\n",
  "        maxiter: Iteration cap passed to the Scipy optimizer.\n",
  "\n",
  "    Returns:\n",
  "        ``{OBJECTIVE: {...}}`` holding the model, its optimizer and the\n",
  "        optimizer arguments, in the form passed to ``bo.optimize``.\n",
  "    \"\"\"\n",
  "    # Initialise the kernel variance from the spread of the observed values.\n",
  "    variance = tf.math.reduce_variance(data.observations)\n",
  "    kernel = gpflow.kernels.Matern52(variance=variance, lengthscales=list(lengthscales))\n",
  "    gpr = gpflow.models.GPR(data.astuple(), kernel, noise_variance=noise_variance)\n",
  "    # The objective is evaluated without added noise here, so pin the small\n",
  "    # likelihood variance instead of training it.\n",
  "    gpflow.set_trainable(gpr.likelihood, False)\n",
  "\n",
  "    return {OBJECTIVE: {\n",
  "        \"model\": gpr,\n",
  "        \"optimizer\": gpflow.optimizers.Scipy(),\n",
  "        \"optimizer_args\": {\n",
  "            \"minimize_args\": {\"options\": dict(maxiter=maxiter)},\n",
  "        },\n",
  "    }}\n",
  "\n",
  "\n",
  "model = build_model(initial_data[OBJECTIVE])"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "5fa0dabe-9930-45bc-86cf-843e1dbe10a1",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"(1, 2)\n",
|
|
"Optimization completed without errors\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)\n",
|
|
"\n",
|
|
"result = bo.optimize(25, initial_data, model)\n",
|
|
"dataset = result.try_get_final_datasets()[OBJECTIVE]"
|
|
]
|
|
},
|
|
{
 "cell_type": "code",
 "execution_count": 8,
 "id": "7d1bcebb-8ac5-4b60-9589-922c2a5d5429",
 "metadata": {},
 "outputs": [
  {
   "name": "stdout",
   "output_type": "stream",
   "text": [
    "query point: [0.54162649 0.14935401]\n",
    "observation: [0.40162417]\n"
   ]
  }
 ],
 "source": [
  "query_points = dataset.query_points.numpy()\n",
  "observations = dataset.observations.numpy()\n",
  "\n",
  "# Single objective: the best point is the row with the smallest observation.\n",
  "# Stay in NumPy rather than round-tripping through tf.argmin / tf.squeeze.\n",
  "arg_min_idx = np.argmin(observations[:, 0])\n",
  "\n",
  "print(f\"query point: {query_points[arg_min_idx, :]}\")\n",
  "print(f\"observation: {observations[arg_min_idx, :]}\")"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 12,
 "id": "297fc50a-f755-40f1-8c9f-aaed111951cf",
 "metadata": {},
 "outputs": [
  {
   "data": {
    "text/plain": [
     "<tf.Tensor: shape=(5, 1), dtype=float64, numpy=\n",
     "array([[ 52.74206955],\n",
     "       [ 20.91833473],\n",
     "       [110.23393455],\n",
     "       [225.05283879],\n",
     "       [ 52.49971937]])>"
    ]
   },
   "execution_count": 12,
   "metadata": {},
   "output_type": "execute_result"
  }
 ],
 "source": [
  "# Observations of the random initial design, for comparison with the best\n",
  "# observation printed above.\n",
  "initial_observations = initial_data[OBJECTIVE].observations\n",
  "initial_observations"
 ]
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.9.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|