{ "cells": [ { "cell_type": "markdown", "id": "2c8984e0-0792-4cf8-b3c6-446b45b717f2", "metadata": {}, "source": [ "# Embedding models\n", "\n", "[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/etna-team/etna/master?filepath=examples/210-embedding_models.ipynb)" ] }, { "cell_type": "markdown", "id": "94e7669f-de54-4df8-86ba-aa72c6d5fb55", "metadata": {}, "source": [ "This notebook contains examples of using embedding models.\n", "\n", "**Table of contents**\n", "\n", "* [Using embedding models directly](#chapter1) \n", "* [Using embedding models with transforms](#chapter2)\n", " * [Baseline](#section_2_1)\n", " * [EmbeddingSegmentTransform](#section_2_2)\n", " * [EmbeddingWindowTransform](#section_2_3)\n", "* [Saving and loading models](#chapter3)\n", "* [Loading external pretrained models](#chapter4)" ] }, { "cell_type": "code", "execution_count": 1, "id": "bf32c6a9-f920-4888-ac9d-f4a1c454cd91", "metadata": { "tags": [] }, "outputs": [], "source": [ "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "id": "d732e5b1-2c10-4de3-93ce-c6395ddbd4f1", "metadata": {}, "source": [ "## 1. Using embedding models directly " ] }, { "cell_type": "markdown", "id": "4c63da5a-eed8-472b-9786-9884a5bb78d1", "metadata": {}, "source": [ "We have two models to generate embeddings for time series: `TS2VecEmbeddingModel` and `TSTCCEmbeddingModel`.\n", "\n", "Each model has the following methods:\n", "\n", "- `fit` to train the model.\n", "- `encode_segment` to generate embeddings for the whole series. These features are regressors.\n", "- `encode_window` to generate embeddings for each timestamp. These features aren't regressors, so a lag transformation should be applied to them before they are used in forecasting.\n", "- `freeze` to enable or disable skipping training in the `fit` method.
It is useful, for example, when you have a pretrained model and only want to generate embeddings without retraining during `backtest`.\n", "- `save` and `load` to save and load pretrained models, respectively." ] }, { "cell_type": "code", "execution_count": 2, "id": "d5ec9757-dd5a-423c-9be1-e4835b4b2a03", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Disabling SSL verification. Connections to this server are not verified and may be insecure!\n", "Global seed set to 42\n" ] }, { "data": { "text/plain": [ "42" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pytorch_lightning import seed_everything\n", "\n", "seed_everything(42, workers=True)" ] }, { "cell_type": "code", "execution_count": 3, "id": "f99c90c5-8a8b-481a-848f-ebcb00b22bb0", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "
segment | segment_0 | segment_1 | segment_2 |
---|---|---|---|
feature | target | target | target |
timestamp | | | |
2001-01-01 | 1.624345 | 1.462108 | -1.100619 |
2001-01-02 | 1.012589 | -0.598033 | 0.044105 |
2001-01-03 | 0.484417 | -0.920450 | 0.945695 |
2001-01-04 | -0.588551 | -1.304504 | 1.448190 |
2001-01-05 | 0.276856 | -0.170735 | 2.349046 |
segment | M1000_MACRO | M1001_MACRO | M1002_MACRO | M1003_MACRO | M1004_MACRO | M1005_MACRO | M1006_MACRO | M1007_MACRO | M1008_MACRO | M1009_MACRO | ... | M992_MACRO | M993_MACRO | M994_MACRO | M995_MACRO | M996_MACRO | M997_MACRO | M998_MACRO | M999_MACRO | M99_MICRO | M9_MICRO |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
feature | target | target | target | target | target | target | target | target | target | target | ... | target | target | target | target | target | target | target | target | target | target |
timestamp | | | | | | | | | | | | | | | | | | | | | |
0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
3 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
4 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |

5 rows × 1428 columns
segment | m1 | m10 | m100 | m101 | m102 | m103 | m104 | m105 | m106 | m107 | ... | m90 | m91 | m92 | m93 | m94 | m95 | m96 | m97 | m98 | m99 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
feature | target | target | target | target | target | target | target | target | target | target | ... | target | target | target | target | target | target | target | target | target | target |
timestamp | | | | | | | | | | | | | | | | | | | | | |
0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
1 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
2 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | ... | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN |
3 | NaN | NaN | 4.0 | 329.0 | 1341.0 | 319.0 | 1419.0 | 462.0 | 921.0 | 3118.0 | ... | 7301.0 | 4374.0 | 803.0 | 191.0 | 124.0 | 319.0 | 270.0 | 36.0 | 109.0 | 38.0 |
4 | NaN | NaN | 40.0 | 439.0 | 1258.0 | 315.0 | 1400.0 | 550.0 | 1060.0 | 2775.0 | ... | 13980.0 | 3470.0 | 963.0 | 265.0 | 283.0 | 690.0 | 365.0 | 31.0 | 158.0 | 74.0 |

5 rows × 366 columns