{ "cells": [ { "cell_type": "markdown", "id": "c4e1eaf2a7aafd", "metadata": {}, "source": [ "# Advanced code examples" ] }, { "cell_type": "markdown", "id": "fdfffce2", "metadata": {}, "source": [ "This notebook covers three advanced topics for compartmental models (CMs):\n", "\n", "1. **Model validation and mass balance** — checking that a CM is physically consistent before using it.\n", "2. **Multi-pool fitting** — simultaneously fitting labeling data from multiple distinct observable pools to a single shared CM.\n", "3. **Model selection with BIC** — using the Bayesian Information Criterion to choose the appropriate number of states without overfitting." ] }, { "cell_type": "markdown", "id": "dd342436135dcc33", "metadata": {}, "source": [ "## Model validation and mass balance" ] }, { "metadata": {}, "cell_type": "code", "source": [ "from symbolic_compartmental_model import SymbolicCompartmentalModel\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from pathlib import Path\n", "\n", "sns.set_style(\"whitegrid\")\n", "\n", "# Create a model that is invalid and also violates both mass-balance conditions\n", "cm = SymbolicCompartmentalModel(n_states=2)\n", "cm.contributed_turnovers = [[-1, 2], [2, -2]]\n", "cm.observed_pool_weights = [0.9, 0.1]\n", "cm.growth_rate = 3.0\n", "\n", "# Check whether the model is valid (it is not: the first row sums to 1 > 0)\n", "if cm.is_valid(raise_exception=False):\n", "    print(\"Model is valid!\")\n", "else:\n", "    print(\"Model is invalid!\")\n", "    print(\"row sums = \", np.sum(cm.M(), axis=1), \" should all be non-positive\\n\\n\")\n", "\n", "# Fix validity: set the off-diagonal entry to 0 so that the first row sums to -1\n", "cm.contributed_turnovers = [[-1, 0], [2, -2]]\n", "\n", "# Check again whether the model is valid (now it is)\n", "if cm.is_valid(raise_exception=False):\n", "    print(\"Model is valid!\")\n", "    print(\"row sums = \", np.sum(cm.M(), axis=1), \" are all non-positive\\n\\n\")\n", "else:\n", "    print(\"Model is invalid!\")\n", "\n", "# Check whether the model is M-mass balanced (it is not)\n",
"if cm.is_M_mass_balanced(raise_exception=False):\n", "    print(\"Model is M-mass balanced!\")\n", "else:\n", "    print(\"Model is not M-mass balanced!\")\n", "    print(\"abscissa = \", cm.abscissa(), \" should be smaller than or equal to -μ = \", -cm.growth_rate, \"\\n\\n\")\n", "\n", "# Fix M-mass balance: lower the growth rate so that abscissa <= -μ\n", "cm.growth_rate = 1.0\n", "\n", "# Check again whether the model is M-mass balanced (now it is)\n", "if cm.is_M_mass_balanced(raise_exception=False):\n", "    print(\"Model is M-mass balanced!\")\n", "    print(\"abscissa = \", cm.abscissa(), \" is smaller than or equal to -μ = \", -cm.growth_rate, \"\\n\\n\")\n", "else:\n", "    print(\"Model is not M-mass balanced!\")\n", "\n", "# Check whether the model is mass balanced (it is not yet)\n", "if cm.is_mass_balanced(raise_exception=False):\n", "    print(\"Model is mass balanced!\")\n", "else:\n", "    print(\"Model is not mass balanced!\")\n", "    print(\"s'(M + μI) = \", cm.s().T @ cm._M_plus_mu_I(), \" should all be non-positive\\n\\n\")\n", "\n", "# Fix mass balance: lower the growth rate further\n", "cm.growth_rate = 0.7\n", "\n", "# Check again whether the model is mass balanced (now it is)\n", "if cm.is_mass_balanced(raise_exception=False):\n", "    print(\"Model is mass balanced!\")\n", "    print(\"s'(M + μI) = \", cm.s().T @ cm._M_plus_mu_I(), \" are all non-positive\\n\\n\")\n", "else:\n", "    print(\"Model is not mass balanced!\")\n" ], "id": "26ebf4772c3eab8a", "outputs": [], "execution_count": null }, { "cell_type": "markdown", "id": "42ddc596", "metadata": {}, "source": [ "A physically meaningful CM must satisfy three conditions:\n", "\n", "**Validity**: every row of the `contributed_turnovers` matrix ($\\mathbf{M}$) must have a non-positive sum, i.e., the (negative) diagonal entry must be at least as large in absolute value as the sum of all other entries in its row.\n", "\n", "**Mass balance**: at steady state, the total efflux from each state must be greater than or equal to the dilution by growth. Formally, the condition can be written as: $$\\mathbf{s}^\\top (\\mathbf{M} + \\mu \\mathbf{I}_n) \\leq \\mathbf{0}_n$$\n", "where $\\mathbf{s}$ is the vector of `observed_pool_weights`, $\\mu$ is the growth rate, $\\mathbf{I}_n$ is the identity matrix of size $n$, and $\\mathbf{0}_n$ is a vector of zeros.\n", "\n", "**M-mass balance**: a CM is M-mass balanced if the spectral abscissa of $\\mathbf{M}$ (the largest real part among its eigenvalues) is smaller than or equal to $-\\mu$. This condition is necessary and sufficient for the existence of weights $\\mathbf{s}$ that satisfy the mass balance condition." ] }, { "cell_type": "markdown", "id": "23320424af2faae5", "metadata": {}, "source": [ "## Fitting multiple pool data to a multi-state CM" ] },
{ "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "# Create a three-state model whose turnover rates k1, k2, k3 are free parameters\n", "cm_multi = SymbolicCompartmentalModel(n_states=3)\n", "\n", "k1 = cm_multi.add_parameter(symbol=\"k1\", lb=0.1, ub=5.0)\n", "k2 = cm_multi.add_parameter(symbol=\"k2\", lb=0.1, ub=10.0)\n", "k3 = cm_multi.add_parameter(symbol=\"k3\", lb=0.1, ub=2.0)\n", "\n", "cm_multi.contributed_turnovers = [[-k1, 0, 0], [k1, -k2, 0], [0, k2, -k3]]\n", "cm_multi.observed_pool_weights = [0.2, 0.3, 0.5]\n", "\n", "# Multi-pool experimental data: (pool_id, time, measurement)\n", "multi_pool_data = [\n", "    (0, 1.0, 0.8), (0, 2.0, 0.6), (0, 3.0, 0.4),\n", "    (1, 1.0, 0.9), (1, 2.0, 0.7), (1, 3.0, 0.5),\n", "    (2, 1.0, 0.95), (2, 2.0, 0.85), (2, 3.0, 0.7)\n", "]\n", "\n", "# Fit the shared parameters to all pools at once (with mass balance constraint)\n", "fit_results = cm_multi.fit_multiple_pools(multi_pool_data)\n", "print(fit_results)\n", "print(f\"RSS: {fit_results.rss:.6f}\")\n", "print(f\"Mean ages: {fit_results.cm.mean_ages()}\")\n", "\n", "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=150)\n", "\n", "colors = [\"cyan\", \"orange\", \"green\"]\n", "for pool_id, color in enumerate(colors):\n", "    # Observe only state `pool_id` (one-hot weights) and draw its fitted curve once\n", "    fit_results.cm.observed_pool_weights = np.eye(3)[pool_id, :]\n", "    fit_results.cm.plot(key=\"f\", ax=ax, color=color, alpha=0.5)\n", "    # Overlay the measurements that belong to this pool\n", "    t_pool = [t for idx, t, _ in multi_pool_data if idx == pool_id]\n", "    y_pool = [y for idx, _, y in multi_pool_data if idx == pool_id]\n", "    ax.plot(t_pool, y_pool, 'x', color=color)\n", "ax.set_xlabel(\"time, t\")\n", "ax.set_ylabel(\"labeling, f(t)\")\n", "if Path(\"../results\").exists():\n", "    fig.savefig(\"../results/example_multiple_pools.svg\")" ], "id": "b3e71a7bceaae292" },
{ "cell_type": "markdown", "id": "6b13bee5", "metadata": {}, "source": [ "When labeling data is available for multiple **observable pools** (e.g., different cellular fractions, subunits, or protein complexes), all measurements can be fit simultaneously to a single shared CM. Each data point is a triplet `(pool_id, time, measurement)`, where `pool_id` selects which linear combination of internal states defines the observable for that measurement.\n", "\n", "`fit_multiple_pools()` optimizes the shared kinetic parameters to minimize the total residual across all pools, so that a single mechanistic model has to account for every observation at once. The returned `FitResult` holds the fitted model in its `cm` attribute, which also exposes `mean_ages()`, the vector of per-state mean ages, in addition to the scalar mean age of the observed mixture."
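, "\n", "\n", "As a sketch of what the total residual could look like, assuming it is the sum of squared residuals reported as `fit_results.rss`, and that a measurement from pool $p$ is compared with the observable obtained from the one-hot weights $\\mathbf{s} = \\mathbf{e}_p$ (as in the plotting code above), the shared parameters $\\boldsymbol{\\theta} = (k_1, k_2, k_3)$ are chosen to minimize\n", "\n", "$$\\text{RSS}(\\boldsymbol{\\theta}) = \\sum_{i=1}^{m} \\left( y_i - f\\left(t_i;\\, \\boldsymbol{\\theta}, \\mathbf{e}_{p_i}\\right) \\right)^2$$\n", "\n", "where $(p_i, t_i, y_i)$ is the $i$-th `(pool_id, time, measurement)` triplet, $f(t; \\boldsymbol{\\theta}, \\mathbf{s})$ is the labeling curve of the CM with parameters $\\boldsymbol{\\theta}$ observed through weights $\\mathbf{s}$, and $m = 9$ in this example. Each pool contributes its own residuals, but all of them depend on the same $\\boldsymbol{\\theta}$, which is what couples the pools in a single fit."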
] }, { "cell_type": "markdown", "id": "1cfa6d9e132efffd", "metadata": { "ExecuteTime": { "end_time": "2025-09-12T11:32:38.938091Z", "start_time": "2025-09-12T11:32:38.936395Z" } }, "source": [ "## Using BIC to choose between different models" ] }, { "metadata": {}, "cell_type": "code", "outputs": [], "execution_count": null, "source": [ "# Generate noisy synthetic data from a known two-state model\n", "cm_ground_truth = SymbolicCompartmentalModel(n_states=2)\n", "cm_ground_truth.contributed_turnovers = [[-1.0, 0], [0.5, -0.5]]\n", "cm_ground_truth.observed_pool_weights = [0.2, 0.8]\n", "\n", "np.random.seed(42)\n", "t_data = np.array([0.1, 0.2, 0.3, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0])\n", "y_true = cm_ground_truth.f()(t_data)\n", "y_data = y_true + 0.02 * np.random.randn(len(y_true))  # Add Gaussian noise (sd = 0.02)\n", "\n", "# Candidate 1: a one-state CM with a single free rate k\n", "cm_one = SymbolicCompartmentalModel(n_states=1)\n", "k = cm_one.add_parameter(symbol=\"k\", lb=0.1, ub=10.0)\n", "cm_one.contributed_turnovers = [[-k]]\n", "\n", "# Candidate 2: a two-state CM with free rates k1, k2 and a free pool weight w\n", "cm_two = SymbolicCompartmentalModel(n_states=2)\n", "k1 = cm_two.add_parameter(symbol=\"k1\", lb=0.1, ub=10.0)\n", "k2 = cm_two.add_parameter(symbol=\"k2\", lb=0.1, ub=10.0)\n", "w = cm_two.add_parameter(symbol=\"w\", lb=0, ub=1)\n", "cm_two.contributed_turnovers = [[-k1, 0.0], [k2, -k2]]\n", "cm_two.observed_pool_weights = [w, 1-w]\n", "\n", "# Fit both candidates to the same noisy data\n", "fit_results_one = cm_one.fit(\n", "    tdata=t_data,\n", "    ydata=y_data,\n", ")\n", "fit_results_two = cm_two.fit(\n", "    tdata=t_data,\n", "    ydata=y_data,\n", ")\n", "\n", "print(\"1-state model\\n\", fit_results_one)\n", "print(\"\\n\\n2-state model\\n\", fit_results_two)\n", "\n", "fig, ax = plt.subplots(1, 1, figsize=(4, 4), dpi=150)\n", "fit_results_one.cm.plot(key=\"f\", ax=ax, t_range=np.linspace(0, 4, 100), color=\"red\", label=\"1-state\")\n",
"fit_results_two.cm.plot(key=\"f\", ax=ax, t_range=np.linspace(0, 4, 100), color=\"green\", label=\"2-state\")\n", "# Plot the noisy data that was actually fitted (not the noise-free y_true)\n", "ax.plot(t_data, y_data, 'x', label=\"synthetic data\", color=\"blue\")\n", "ax.legend()\n", "ax.set_xlabel(\"time, t\")\n", "ax.set_ylabel(\"labeling, f(t)\")\n", "if Path(\"../results\").exists():\n", "    fig.savefig(\"../results/example_bic.svg\")" ], "id": "82838b0ae9c1ae94" }, { "cell_type": "markdown", "id": "87a0a926", "metadata": {}, "source": [ "A more complex model (one with more free parameters) fits the data at least as well as a simpler model nested within it, but it may overfit noise. Assuming independent, normally distributed measurement errors (the Gaussian special case), the [**Bayesian Information Criterion (BIC)**](https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case) can be written as:\n", "\n", "$$\\text{BIC} = m \\cdot \\ln(\\text{RSS} / m) + n \\cdot \\ln(m)$$\n", "\n", "where $n$ is the number of free parameters, $m$ is the number of observed data points, and $\\text{RSS}$ is the residual sum of squares between the fitted curve and the data points. A **lower BIC** indicates a better trade-off between fit quality and parsimony. In the example above, $m = 9$, while $n = 1$ for the one-state model ($k$) and $n = 3$ for the two-state model ($k_1$, $k_2$, $w$), so the two-state model is preferred only if its reduction in $m \\cdot \\ln(\\text{RSS}/m)$ exceeds the extra penalty $(3 - 1) \\ln(9) \\approx 4.39$." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }