{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "25d01f84-863b-4238-964b-e24425bb8107", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 2, "id": "242c5eb2-2531-481d-8c21-0b0178ac2db5", "metadata": {}, "outputs": [], "source": [ "np.random.seed(1793)\n", "tf.random.set_seed(1793)" ] }, { "cell_type": "code", "execution_count": 3, "id": "49fa101c-15a9-47be-b254-6cbde6dd04e9", "metadata": {}, "outputs": [], "source": [ "import trieste\n", "from trieste.utils.objectives import branin\n", "\n", "search_space = trieste.space.Box([0, 0], [1, 1])" ] }, { "cell_type": "code", "execution_count": 4, "id": "050686af-8c93-4996-b852-c925d7cc7e98", "metadata": {}, "outputs": [], "source": [ "def new_branin(data):\n", " print(data.shape)\n", " return branin(data)" ] }, { "cell_type": "code", "execution_count": 5, "id": "cd44cddb-5090-4e4f-b651-f7d57f145b70", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(5, 2)\n" ] } ], "source": [ "from trieste.acquisition.rule import OBJECTIVE\n", "\n", "observer = trieste.utils.objectives.mk_observer(new_branin, OBJECTIVE)\n", "\n", "num_initial_points = 5\n", "initial_query_points = search_space.sample(num_initial_points)\n", "initial_data = observer(initial_query_points)" ] }, { "cell_type": "code", "execution_count": 6, "id": "696dfd67-0e37-42f0-af5b-64937bffc2aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Building a model\n" ] } ], "source": [ "import gpflow\n", "\n", "def build_model(data):\n", " print(\"Building a model\")\n", " variance = tf.math.reduce_variance(data.observations)\n", " kernel = gpflow.kernels.Matern52(variance=variance, lengthscales=[0.2, 0.2])\n", " gpr = gpflow.models.GPR(data.astuple(), kernel, noise_variance=1e-5)\n", " gpflow.set_trainable(gpr.likelihood, False)\n", "\n", " return {OBJECTIVE: {\n", " \"model\": gpr,\n", " \"optimizer\": gpflow.optimizers.Scipy(),\n", " \"optimizer_args\": {\n", " \"minimize_args\": {\"options\": dict(maxiter=100)},\n", " },\n", " }}\n", "\n", "model = build_model(initial_data[OBJECTIVE])" ] }, { "cell_type": "code", "execution_count": 7, "id": "5fa0dabe-9930-45bc-86cf-843e1dbe10a1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "(1, 2)\n", "Optimization completed without errors\n" ] } ], "source": [ "bo = trieste.bayesian_optimizer.BayesianOptimizer(observer, search_space)\n", "\n", "result = bo.optimize(25, initial_data, model)\n", "dataset = result.try_get_final_datasets()[OBJECTIVE]" ] }, { "cell_type": "code", "execution_count": 8, "id": "7d1bcebb-8ac5-4b60-9589-922c2a5d5429", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "query point: [0.54162649 0.14935401]\n", "observation: [0.40162417]\n" ] } ], "source": [ "query_points = dataset.query_points.numpy()\n", "observations = dataset.observations.numpy()\n", "\n", "arg_min_idx = tf.squeeze(tf.argmin(observations, axis=0))\n", "\n", "print(f\"query point: {query_points[arg_min_idx, :]}\")\n", "print(f\"observation: {observations[arg_min_idx, :]}\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "297fc50a-f755-40f1-8c9f-aaed111951cf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "initial_data[OBJECTIVE].observations" ] }, { "cell_type": "code", "execution_count": null, "id": "624720e0-5600-4a74-a2ce-6d09e248c069", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.3" } }, "nbformat": 4, "nbformat_minor": 5 }