model.py

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 SWGY, Inc
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
Pressure Wave Simulation
========================

This module simulates pressure waves reflecting off geometry and calculates
the pressure effects on sensors. It uses a ray-based approach with specular
reflections to model blast wave propagation in complex environments.
"""

from __future__ import annotations
from dataclasses import dataclass
from mpl_toolkits.mplot3d import Axes3D  # Required for 3D plotting
from typing import List
from typing import List, Tuple, Optional, Callable, Sequence, Union
import logging  # Added for logging
import numpy as np
import trimesh

# Type definitions using modern dataclasses
@dataclass
class Sensor:
    """Spherical sensing volume.

    Fields are positional in the order ``origin, radius``.
    """

    origin: np.ndarray  # centre point of the sphere
    radius: float       # sphere radius (define_sensor checks it is positive)


@dataclass
class Extents:
    """Angular window, in radians, as seen from a viewpoint.

    Field order (callers construct this positionally) is
    left, right, top, bottom.
    """

    left: float    # smallest azimuth
    right: float   # largest azimuth
    top: float     # largest elevation
    bottom: float  # smallest elevation


@dataclass
class PressureStats:
    """Pressure summary for a single point, such as a sensor."""

    peak_overpressure: float  # largest overpressure observed
    total_impulse: float      # accumulated impulse


# Vector operations
def calculate_distance(point1: Sequence[float], point2: Sequence[float]) -> float:
    """
    Return the Euclidean distance between two points.

    Args:
        point1: Coordinates of the first point (any array-like)
        point2: Coordinates of the second point (same length as point1)

    Returns:
        Distance between the two points as a float
    """
    # Convert to float before subtracting: with unsigned-integer inputs the
    # subtraction would wrap around instead of going negative, producing a
    # wildly wrong distance.
    delta = np.asarray(point1, dtype=float) - np.asarray(point2, dtype=float)
    return float(np.linalg.norm(delta))

def unit_vector(v: np.ndarray) -> np.ndarray:
    """
    Scale a vector to unit length.

    Args:
        v: Vector to normalise

    Returns:
        ``v`` divided by its Euclidean norm, or an all-zero array shaped like
        ``v`` when the norm is below 1e-12 (avoids dividing by ~0)
    """
    magnitude = np.linalg.norm(v)
    if magnitude >= 1e-12:
        return v / magnitude
    return np.zeros_like(v)


def reflect(direction: np.ndarray, normal: np.ndarray) -> np.ndarray:
    """
    Mirror a direction vector across the plane with the given normal.

    Implements r = d - 2 (d . n) n. The result does not depend on the sign of
    the normal.

    Args:
        direction: Incoming direction (expected to be unit length)
        normal: Plane normal (expected to be unit length)

    Returns:
        Mirrored direction vector
    """
    along_normal = np.dot(direction, normal)
    return direction - 2.0 * along_normal * normal


def create_coordinate_frame(forward: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Create an orthonormal coordinate frame whose first axis is ``forward``.

    Args:
        forward: Forward direction vector. A (near-)zero vector, e.g. when the
            sensor sits exactly on the blast point, falls back to the +x axis
            so that the result is still a valid basis.

    Returns:
        Tuple of (forward, right, up) unit vectors forming an orthonormal basis
    """
    forward = unit_vector(forward)
    if not np.any(forward):
        # unit_vector returns zeros for degenerate input; use a default axis
        # rather than returning an all-zero (non-orthonormal) frame.
        forward = np.array([1.0, 0.0, 0.0])

    # Reference up vector; switch to +y when forward is nearly vertical so
    # the cross product below stays well conditioned.
    up = np.array([0.0, 0.0, 1.0])
    if abs(np.dot(forward, up)) > 0.99:
        up = np.array([0.0, 1.0, 0.0])

    right = unit_vector(np.cross(forward, up))
    # Re-derive up so it is exactly perpendicular to forward and right
    up = unit_vector(np.cross(right, forward))

    return forward, right, up


# Intersection calculations
def line_sphere_intersect(
        p0: np.ndarray, 
        p1: np.ndarray, 
        center: np.ndarray, 
        radius: float
        ) -> Tuple[bool, Optional[np.ndarray], Optional[float]]:
    """
    Intersect the line segment p0 -> p1 with a sphere.

    Solves |p0 + t (p1 - p0) - center|^2 = radius^2 for t and accepts the
    smallest root with 0 <= t <= 1. If p0 is inside the sphere, the smaller
    root is negative, so the exit point is reported (when it lies on the
    segment). Zero-length segments never hit.

    Args:
        p0: Start point of line segment
        p1: End point of line segment
        center: Center of the sphere
        radius: Radius of the sphere

    Returns:
        Tuple (hit_flag, hit_point, distance):
        - hit_flag: True if the segment touches the sphere
        - hit_point: First intersection point on the segment (None if no hit)
        - distance: Euclidean distance from p0 to hit_point (None if no hit)
    """
    start = np.asarray(p0)
    end = np.asarray(p1)
    centre = np.asarray(center)

    seg = end - start
    offset = start - centre

    # Coefficients of a*t^2 + b*t + c = 0
    a = np.dot(seg, seg)
    b = 2.0 * np.dot(offset, seg)
    c = np.dot(offset, offset) - radius ** 2

    discriminant = b ** 2 - 4 * a * c
    if discriminant < 0 or abs(a) < 1e-14:
        return False, None, None

    root = np.sqrt(discriminant)
    roots = sorted([(-b - root) / (2 * a), (-b + root) / (2 * a)])
    t_hit = next((t for t in roots if 0 <= t <= 1), None)
    if t_hit is None:
        return False, None, None

    point = start + t_hit * seg
    return True, point, np.linalg.norm(point - start)


def find_mesh_intersection(
        geometry: List[trimesh.base.Trimesh], 
        origin: np.ndarray, 
        direction: np.ndarray
        ) -> Tuple[bool, Optional[np.ndarray], Optional[float], Optional[np.ndarray]]:
    """
    Return the nearest hit of a ray against any mesh in a list.

    Each mesh is queried for its first hit along the ray. The hit whose
    location is closest to ``origin`` wins.

    Args:
        geometry: List of trimesh meshes (may be empty)
        origin: Ray origin
        direction: Ray direction

    Returns:
        Tuple (hit_flag, hit_point, hit_distance, hit_normal):
        - hit_flag: True if any mesh was hit
        - hit_point: Nearest intersection point (None if no hit)
        - hit_distance: Distance from origin to hit_point (None if no hit)
        - hit_normal: Face normal of the triangle that was hit (None if no hit)
    """
    nearest = np.inf
    best = (False, None, None, None)

    for mesh in geometry:
        locs, _, tri_ids = mesh.ray.intersects_location(
                ray_origins=[origin],
                ray_directions=[direction],
                multiple_hits=False
                )
        if locs is None or not locs.shape[0] > 0:
            continue

        candidate = locs[0]
        dist = np.linalg.norm(candidate - origin)
        if dist < nearest:
            nearest = dist
            best = (True, candidate, dist, mesh.face_normals[tri_ids[0]])

    return best


# Waveform modeling
def friedlander_waveform(
        t: Union[float, np.ndarray], 
        p_max: float, 
        t_d: float, 
        alpha: float = 1.0
        ) -> Union[float, np.ndarray]:
    """
    Compute the Friedlander waveform, a model of blast overpressure vs. time.

    The Friedlander waveform is defined as:
    p(t) = p_max * (1 - t/t_d) * exp(-alpha * t/t_d) for t >= 0,
    and p(t) = 0 for t < 0.
    Values that would be negative (the underpressure phase, t > t_d for a
    positive p_max) are clipped to zero.

    Args:
        t: Time or array of times
        p_max: Peak overpressure
        t_d: Positive phase duration
        alpha: Decay parameter

    Returns:
        Pressure value(s) at the given time(s): a Python float for a scalar
        ``t``, otherwise a float ndarray with the same shape as ``t``
    """
    t_array = np.asarray(t)
    p = np.zeros_like(t_array, dtype=float)

    mask = t_array >= 0
    p[mask] = p_max * (1 - t_array[mask] / t_d) * np.exp(-alpha * t_array[mask] / t_d)

    # Clip negative pressures to zero. Although underpressure is a real
    # phenomenon, this model does not explore it.
    p[p < 0] = 0.0

    # Honour the scalar-in / scalar-out contract of the type hints instead of
    # leaking a 0-d array to scalar callers.
    if p.ndim == 0:
        return float(p)
    return p



def define_sensor(origin: Sequence[float], radius: float) -> Sensor:
    """
    Define a sensor as a sphere with the given origin and radius.

    Args:
        origin: 3D coordinates of the sensor center (array-like of three numbers)
        radius: Radius of the sensor; must be strictly positive

    Returns:
        Sensor whose origin is a new float array of shape (3,). It is a copy,
        so later changes to ``origin`` do not affect the sensor.

    Raises:
        ValueError: If origin is not a flat sequence of three numbers, or if
            radius is not positive (NaN included)
    """
    # Copy into a float array and check the full shape, not just len():
    # nested inputs such as [[1], [2], [3]] have len 3 but would break the
    # vector math downstream.
    origin_array = np.array(origin, dtype=float)
    if origin_array.shape != (3,):
        raise ValueError("Origin must have three components.")
    # `not radius > 0` also rejects NaN, which `radius <= 0` would let through
    if not radius > 0:
        raise ValueError("Radius must be positive.")

    return Sensor(origin_array, radius)


def _spherical_angles(
        rel: np.ndarray,
        forward: np.ndarray,
        right: np.ndarray,
        up: np.ndarray
        ) -> Tuple[np.ndarray, np.ndarray]:
    """
    Return (azimuth, elevation) of vector(s) ``rel`` in the given frame.

    Uses the spherical convention that simulate_pressure uses to build ray
    directions:
        dir = cos(el)*cos(az)*forward + cos(el)*sin(az)*right + sin(el)*up
    so azimuth lies in [-pi, pi] and elevation in [-pi/2, pi/2].
    """
    f = np.dot(rel, forward)
    r = np.dot(rel, right)
    u = np.dot(rel, up)
    azimuth = np.arctan2(r, f)
    # Elevation is measured from the forward/right plane, not from `forward`
    # alone; atan2(u, f) would overshoot pi/2 for points behind the blast.
    elevation = np.arctan2(u, np.hypot(f, r))
    return azimuth, elevation


def compute_extents(
        geometry: List[trimesh.base.Trimesh],
        blast_point: np.ndarray,
        sensor: Sensor
        ) -> Extents:
    """
    Compute the angular extents from the blast point to the limits of the
    geometry. The sensor will always be included within the extents.

    Angles are expressed in a frame whose forward axis points from the blast
    point to the sensor. They use the same (azimuth, elevation) convention
    that simulate_pressure uses to generate rays, so elevations stay within
    [-pi/2, pi/2].

    Args:
        geometry: List of meshes to include in the calculations (may be empty;
            meshes without vertices are ignored)
        blast_point: Location (x, y, z) of the blast.
        sensor: Sensor object

    Returns:
        Extents object for the left, right, top, and bottom angular extents
    """
    blast_point = np.asarray(blast_point, dtype=float)
    sensor_rel = sensor.origin - blast_point
    forward, right_axis, up_axis = create_coordinate_frame(sensor_rel)

    # Sensor footprint: angular radius of the sphere as seen from the blast,
    # or a full hemisphere when the blast point lies inside the sensor.
    sensor_distance = np.linalg.norm(sensor_rel)
    if sensor_distance <= sensor.radius:
        sensor_half_angle = np.pi / 2
    else:
        sensor_half_angle = np.arcsin(sensor.radius / sensor_distance)
    sensor_az, sensor_el = _spherical_angles(sensor_rel, forward, right_axis, up_axis)

    left = sensor_az - sensor_half_angle
    right = sensor_az + sensor_half_angle
    bottom = sensor_el - sensor_half_angle
    top = sensor_el + sensor_half_angle

    # Widen the window to cover every geometry vertex, if there are any
    if geometry:
        vertices = np.concatenate(
                [np.asarray(mesh.vertices, dtype=float).reshape(-1, 3) for mesh in geometry],
                axis=0
                )
        if vertices.shape[0] > 0:
            azimuths, elevations = _spherical_angles(
                    vertices - blast_point, forward, right_axis, up_axis
                    )
            left = min(left, azimuths.min())
            right = max(right, azimuths.max())
            bottom = min(bottom, elevations.min())
            top = max(top, elevations.max())

    return Extents(left, right, top, bottom)

# Ray-based pressure calculation
def trace_ray(
        ray_dir: np.ndarray,
        geometry: List[trimesh.base.Trimesh], 
        blast_point: np.ndarray, 
        p0: float, 
        sensor: Sensor, 
        max_bounces: int = 3, 
        max_length: float = 50.0
        ) -> PressureStats:
    """
    Trace a single ray from blast_point in direction ray_dir and compute the
    cumulative pressure at the sensor.

    The ray is followed segment by segment. Each segment is tested against the
    sensor sphere and the geometry:
    - If the sensor is reached first, its contribution is recorded and tracing
      stops.
    - Otherwise the ray is specularly reflected off the geometry, losing 20%
      of its pressure per bounce.

    The peak overpressure at the sensor is ``reflection_factor * p0 / d**2``,
    where d is the total path length travelled. The impulse is the time
    integral of a Friedlander waveform with that peak.

    Args:
        ray_dir: Direction of the ray
        geometry: List of trimesh meshes
        blast_point: Blast source coordinate
        p0: Peak pressure at the blast (in kPa)
        sensor: Sensor object
        max_bounces: Maximum number of reflections to calculate
        max_length: Maximum tracing distance (m)

    Returns:
        PressureStats with peak overpressure and total impulse (both zero when
        the sensor is not reached)
    """
    # Blast wave parameters
    t_d = 0.01              # Positive phase duration (s)
    alpha = 1.0             # Friedlander decay parameter
    reflection_loss = 0.8   # Fraction of pressure kept per reflection
    # A reflected ray starts exactly on the surface it just hit. Without a
    # small offset, the next intersection query can return that same face at
    # (near) zero distance and burn the remaining bounces on it.
    surface_offset = 1e-6   # m

    # np.trapz was deprecated in NumPy 2.0 in favour of np.trapezoid
    integrate = getattr(np, "trapezoid", None) or np.trapz

    current_point = np.array(blast_point, dtype=float)
    current_direction = unit_vector(ray_dir)
    current_distance = 0.0
    bounces = 0
    reflection_factor = 1.0

    overall_peak = 0.0
    overall_impulse = 0.0

    if not np.any(current_direction):
        # Zero-length direction: every traced segment would be degenerate
        return PressureStats(peak_overpressure=overall_peak, total_impulse=overall_impulse)

    while current_distance < max_length and bounces <= max_bounces:
        remaining_distance = max_length - current_distance
        segment_end = current_point + current_direction * remaining_distance

        # Check intersection with sensor sphere (bounded by remaining length)
        hit_sensor, _, sensor_hit_distance = line_sphere_intersect(
                current_point, segment_end, sensor.origin, sensor.radius
                )

        # Check intersection with geometry (unbounded; length checked below)
        geom_hit, geom_hit_point, geom_hit_distance, geom_hit_normal = find_mesh_intersection(
                geometry, current_point, current_direction
                )

        # Ray hits sensor before geometry (or without hitting geometry)
        if hit_sensor and (not geom_hit or sensor_hit_distance < geom_hit_distance):
            # Clamp to avoid division by zero
            effective_distance = max(current_distance + sensor_hit_distance, 1e-6)

            # Inverse-square attenuation scaled by accumulated reflection losses
            p_eff = reflection_factor * p0 / (effective_distance ** 2)

            # Sample the pressure waveform to compute impulse
            t_samples = np.linspace(0, 3 * t_d, 1000)
            pressures = friedlander_waveform(t_samples, p_eff, t_d, alpha)

            overall_peak = max(overall_peak, p_eff)
            overall_impulse += integrate(pressures, t_samples)
            break

        if not geom_hit:
            # Ray escapes without hitting anything
            break

        current_distance += geom_hit_distance
        # Stop when out of range, or when another reflection could never be
        # followed because the bounce budget is spent
        if current_distance >= max_length or bounces >= max_bounces:
            break

        # Reflect, then step off the surface along the new direction
        current_direction = unit_vector(reflect(current_direction, geom_hit_normal))
        current_point = geom_hit_point + current_direction * surface_offset
        reflection_factor *= reflection_loss
        bounces += 1

    return PressureStats(peak_overpressure=overall_peak, total_impulse=overall_impulse)



def simulate_pressure(
        geometry: List[trimesh.base.Trimesh],
        blast_point: np.ndarray,
        sensor: Sensor,
        extents: Extents,
        resolution_arcmin: float,
        ray_callback: Callable,
        p0: float,
        max_length: float,
        max_bounces: int
        ) -> PressureStats:
    """
    Compute the cumulative pressure at a sensor by tracing rays evenly over angular extents.

    Rays are laid out on a regular (azimuth, elevation) grid in a frame whose
    forward axis points from the blast point to the sensor. Each ray's
    impulse is weighted by the solid angle of its grid cell, cos(el) * res^2.
    The peak overpressure is the maximum over all rays.

    Args:
        geometry: List of trimesh meshes (can be empty)
        blast_point: Blast source coordinate
        sensor: Sensor object
        extents: Angular extents to trace (radians)
        resolution_arcmin: Angular resolution in arc-minutes (must be > 0)
        ray_callback: Function to compute pressure for a given ray, called as
            ray_callback(ray_dir, geometry, blast_point, p0, sensor,
            max_bounces=..., max_length=...) and returning PressureStats
        p0: Peak pressure at the blast (in kPa)
        max_length: Maximum tracing distance (m)
        max_bounces: Stop calculating when a ray has bounced more than this

    Returns:
        PressureStats with cumulative peak overpressure and total impulse

    Raises:
        ValueError: If resolution_arcmin is not positive
    """
    if not resolution_arcmin > 0:
        raise ValueError("resolution_arcmin must be positive.")

    # Module logger: unlike the module-level logging.debug(), this never
    # implicitly calls basicConfig(), and the %-args are formatted lazily.
    logger = logging.getLogger(__name__)

    # Convert resolution from arc-minutes to radians
    res_rad = resolution_arcmin * (np.pi / 10800)  # 1 arc-minute = (pi/180)/60

    # Always use sensor direction for the local coordinate frame
    forward, right, up = create_coordinate_frame(sensor.origin - blast_point)

    overall_peak = 0.0
    overall_impulse = 0.0

    # Evenly distribute rays within provided angular extents
    azimuth_angles = np.arange(extents.left, extents.right + res_rad, res_rad)
    elevation_angles = np.arange(extents.bottom, extents.top + res_rad, res_rad)
    # np.arange can overshoot `top` by up to one step. Elevations beyond the
    # poles would re-sample directions already covered and receive negative
    # cos(el) solid-angle weights, so drop them.
    elevation_angles = elevation_angles[np.abs(elevation_angles) <= np.pi / 2]

    total_rays = azimuth_angles.size * elevation_angles.size
    logger.debug("Starting simulation: %d ray calculations will run", total_rays)
    ray_count = 0

    for el in elevation_angles:
        cos_el = np.cos(el)
        sin_el = np.sin(el)
        # Differential solid angle of one grid cell at this elevation
        d_omega = cos_el * (res_rad ** 2)

        for az in azimuth_angles:
            ray_count += 1
            # Convert angular coordinates to direction vector
            ray_dir = (
                    forward * cos_el * np.cos(az) +
                    right * cos_el * np.sin(az) +
                    up * sin_el
                    )

            # Trace ray and compute pressure stats
            stats = ray_callback(
                    ray_dir, geometry, blast_point, p0, sensor,
                    max_bounces=max_bounces, max_length=max_length
                    )

            # Periodic progress logging
            if ray_count % 500 == 0:
                logger.debug("Ray %d/%d finished", ray_count, total_rays)

            # Update cumulative statistics
            overall_peak = max(overall_peak, stats.peak_overpressure)
            overall_impulse += stats.total_impulse * d_omega

    return PressureStats(peak_overpressure=overall_peak, total_impulse=overall_impulse)