Updated Notebooks
This commit is contained in:
parent
32873283c9
commit
344bb9f8e4
157 changed files with 97952 additions and 2350 deletions
236
Notebooks/Trieste_GP_opt.ipynb
Normal file
236
Notebooks/Trieste_GP_opt.ipynb
Normal file
|
@ -0,0 +1,236 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "25d01f84-863b-4238-964b-e24425bb8107",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import tensorflow as tf"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "242c5eb2-2531-481d-8c21-0b0178ac2db5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.random.seed(1793)\n",
|
||||
"tf.random.set_seed(1793)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "49fa101c-15a9-47be-b254-6cbde6dd04e9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import trieste\n",
|
||||
"from trieste.utils.objectives import branin\n",
|
||||
"\n",
|
||||
"search_space = trieste.space.Box([0, 0], [1, 1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "050686af-8c93-4996-b852-c925d7cc7e98",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def new_branin(data):\n",
|
||||
" print(data.shape)\n",
|
||||
" return branin(data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "cd44cddb-5090-4e4f-b651-f7d57f145b70",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(5, 2)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from trieste.acquisition.rule import OBJECTIVE\n",
|
||||
"\n",
|
||||
"observer = trieste.utils.objectives.mk_observer(new_branin, OBJECTIVE)\n",
|
||||
"\n",
|
||||
"num_initial_points = 5\n",
|
||||
"initial_query_points = search_space.sample(num_initial_points)\n",
|
||||
"initial_data = observer(initial_query_points)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "696dfd67-0e37-42f0-af5b-64937bffc2aa",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Building a model\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import gpflow\n",
|
||||
"\n",
|
||||
"def build_model(data):\n",
|
||||
" print(\"Building a model\")\n",
|
||||
" variance = tf.math.reduce_variance(data.observations)\n",
|
||||
" kernel = gpflow.kernels.Matern52(variance=variance, lengthscales=[0.2, 0.2])\n",
|
||||
" gpr = gpflow.models.GPR(data.astuple(), kernel, noise_variance=1e-5)\n",
|
||||
" gpflow.set_trainable(gpr.likelihood, False)\n",
|
||||
"\n",
|
||||
" return {OBJECTIVE: {\n",
|
||||
" \"model\": gpr,\n",
|
||||
" \"optimizer\": gpflow.optimizers.Scipy(),\n",
|
||||
" \"optimizer_args\": {\n",
|
||||
" \"minimize_args\": {\"options\": dict(maxiter=100)},\n",
|
||||
" },\n",
|
||||
" }}\n",
|
||||
"\n",
|
||||
"model = build_model(initial_data[OBJECTIVE])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "5fa0dabe-9930-45bc-86cf-843e1dbe10a1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"(1, 2)\n",
|
||||
"Optimization completed without errors\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)\n",
|
||||
"\n",
|
||||
"result = bo.optimize(25, initial_data, model)\n",
|
||||
"dataset = result.try_get_final_datasets()[OBJECTIVE]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "7d1bcebb-8ac5-4b60-9589-922c2a5d5429",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"query point: [0.54162649 0.14935401]\n",
|
||||
"observation: [0.40162417]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"query_points = dataset.query_points.numpy()\n",
|
||||
"observations = dataset.observations.numpy()\n",
|
||||
"\n",
|
||||
"arg_min_idx = tf.squeeze(tf.argmin(observations, axis=0))\n",
|
||||
"\n",
|
||||
"print(f\"query point: {query_points[arg_min_idx, :]}\")\n",
|
||||
"print(f\"observation: {observations[arg_min_idx, :]}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "297fc50a-f755-40f1-8c9f-aaed111951cf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"<tf.Tensor: shape=(5, 1), dtype=float64, numpy=\n",
|
||||
"array([[ 52.74206955],\n",
|
||||
" [ 20.91833473],\n",
|
||||
" [110.23393455],\n",
|
||||
" [225.05283879],\n",
|
||||
" [ 52.49971937]])>"
|
||||
]
|
||||
},
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"initial_data[OBJECTIVE].observations"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "624720e0-5600-4a74-a2ce-6d09e248c069",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue