"""Flask app that streams YOLO-annotated video and counts vehicles crossing a line.

Frames come either from the default webcam or from an uploaded video file.
Each stream is served as a multipart MJPEG response; the counting line and the
per-stream identifiers are stored in the Flask session.
"""
import os
import uuid
from threading import Event, Lock

import cv2
import numpy as np
from flask import (
    Flask,
    render_template,
    Response,
    request,
    redirect,
    url_for,
    send_from_directory,
    session,
    jsonify,
    abort,
)
from werkzeug.utils import secure_filename
from ultralytics import YOLO

app = Flask(__name__)
# NOTE(review): the fallback secret is hard-coded; set SECRET_KEY in any
# non-development deployment, otherwise session cookies can be forged.
app.secret_key = os.environ.get('SECRET_KEY', 'vehicle_dev_secret')
app.config['MAX_CONTENT_LENGTH'] = 200 * 1024 * 1024  # 200MB upload cap

# Load the YOLO detection model (yolo11s weights)
model = YOLO("yolo11s.pt")
names = model.model.names

# Vehicle classes to count
VEHICLE_CLASSES = {'car', 'truck', 'bus', 'motorcycle'}
ALLOWED_EXTENSIONS = {'mp4', 'mov', 'avi', 'mkv'}
UPLOAD_DIR = 'uploads'

# Track reset events per stream (webcam/video per session)
reset_events: dict[str, Event] = {}
reset_lock = Lock()


def allowed_file(filename: str) -> bool:
    """Return True if *filename* has one of the allowed video extensions."""
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def ensure_upload_dir() -> None:
    """Create the upload directory if it does not exist yet."""
    # exist_ok avoids the check-then-create race between concurrent uploads.
    os.makedirs(UPLOAD_DIR, exist_ok=True)


def get_line_from_session():
    """Return the session's counting line, storing a default one if none is set."""
    if 'counting_line' not in session:
        session['counting_line'] = {'x1': 0, 'y1': 300, 'x2': 1020, 'y2': 300}
    return session['counting_line']


def fresh_vehicle_counts() -> dict[str, int]:
    """Return a new per-class counter with every vehicle class at zero."""
    return {vehicle: 0 for vehicle in VEHICLE_CLASSES}


def get_reset_event(stream_id: str) -> Event:
    """Return the reset event registered for *stream_id*, creating it if needed."""
    with reset_lock:
        event = reset_events.get(stream_id)
        if event is None:
            event = Event()
            reset_events[stream_id] = event
        return event


def release_reset_event(stream_id: str) -> None:
    """Remove the reset event registered for *stream_id*, if any."""
    with reset_lock:
        reset_events.pop(stream_id, None)


def get_webcam_stream_id() -> str:
    """Return the session's webcam stream id, generating one on first use."""
    stream_id = session.get('webcam_stream_id')
    if not stream_id:
        stream_id = f'webcam-{uuid.uuid4().hex}'
        session['webcam_stream_id'] = stream_id
    return stream_id


def get_video_stream_id(filename: str) -> str:
    """Return the session's stream id for *filename*, generating one on first use."""
    video_streams = session.get('video_stream_ids', {})
    stream_id = video_streams.get(filename)
    if not stream_id:
        stream_id = f'video-{uuid.uuid4().hex}'
        video_streams[filename] = stream_id
        # Re-assign so the session registers the change to the nested dict.
        session['video_stream_ids'] = video_streams
    return stream_id


# Helper function to check if two line segments intersect
def line_intersect(p1, p2, p3, p4):
    """
    Check if line segment p1-p2 intersects with line segment p3-p4.

    Each point is an (x, y) pair. Returns True if the segments intersect
    (endpoints included), False otherwise. Parallel and collinear segments
    are reported as not intersecting.
    """
    x1, y1 = p1
    x2, y2 = p2
    x3, y3 = p3
    x4, y4 = p4

    denom = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
    if abs(denom) < 1e-10:
        # Near-zero determinant: the segments are parallel (or degenerate).
        return False

    # t and u are the intersection's positions along p1-p2 and p3-p4.
    t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / denom
    u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / denom
    return 0 <= t <= 1 and 0 <= u <= 1


def crossed_line(prev_pos, curr_pos, line_start, line_end):
    """
    Check if movement from prev_pos to curr_pos crossed the line.

    The movement is treated as a segment and tested for intersection with
    the counting line, so a crossing is detected even when the object jumps
    over the line between two processed frames.
    """
    return line_intersect(prev_pos, curr_pos, line_start, line_end)


def _draw_counting_line(frame, line_start, line_end) -> None:
    """Draw the counting line as a yellow line overlaid with black dashes."""
    cv2.line(frame, line_start, line_end, (0, 255, 255), 3, cv2.LINE_AA)

    dx = line_end[0] - line_start[0]
    dy = line_end[1] - line_start[1]
    line_length = int(np.sqrt(dx ** 2 + dy ** 2))
    dash_length = 20
    # max(..., 1) keeps one iteration for a zero-length line; the guards below
    # avoid dividing by zero in that case.
    for offset in range(0, max(line_length, 1), dash_length * 2):
        t1 = offset / line_length if line_length else 0
        t2 = min((offset + dash_length) / line_length, 1.0) if line_length else 0
        dash_start = (int(line_start[0] + t1 * dx), int(line_start[1] + t1 * dy))
        dash_end = (int(line_start[0] + t2 * dx), int(line_start[1] + t2 * dy))
        cv2.line(frame, dash_start, dash_end, (0, 0, 0), 3)


def _draw_counter_overlay(frame, vehicle_count, vehicle_type_counts) -> None:
    """Draw the black statistics panel with the total and per-class counts."""
    cv2.rectangle(frame, (10, 10), (350, 140), (0, 0, 0), -1)
    cv2.putText(frame, f'Gesamt: {vehicle_count}', (20, 35),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    cv2.putText(frame, f"Autos: {vehicle_type_counts.get('car', 0)}", (20, 65),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
    cv2.putText(frame, f"LKW: {vehicle_type_counts.get('truck', 0)}", (20, 90),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
    cv2.putText(frame, f"Busse: {vehicle_type_counts.get('bus', 0)}", (20, 115),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (200, 200, 200), 1)
    cv2.putText(frame, f"Motorraeder: {vehicle_type_counts.get('motorcycle', 0)}", (20, 135),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)


def generate_frames(capture, line_data, stream_id: str):
    """
    Shared frame generator for webcam and uploaded videos.
    Handles detection, drawing overlays, and reset events.

    Args:
        capture: An opened capture object exposing ``read()`` and ``release()``.
        line_data: Mapping with the counting line's ``x1``, ``y1``, ``x2``, ``y2``.
        stream_id: Identifier used to look up this stream's reset event.

    Yields:
        Multipart chunks (boundary ``frame``), each holding one JPEG frame.

    The capture is released and the reset event unregistered when the
    generator finishes or is closed.
    """
    track_positions = {}
    counted_ids = set()
    vehicle_count = 0
    vehicle_type_counts = fresh_vehicle_counts()

    line_start = (line_data['x1'], line_data['y1'])
    line_end = (line_data['x2'], line_data['y2'])

    frame_idx = 0
    reset_event = get_reset_event(stream_id)

    try:
        while True:
            ret, frame = capture.read()
            if not ret:
                break

            # Only every second frame is processed.
            frame_idx += 1
            if frame_idx % 2 != 0:
                continue

            if reset_event.is_set():
                track_positions.clear()
                counted_ids.clear()
                vehicle_count = 0
                vehicle_type_counts = fresh_vehicle_counts()
                reset_event.clear()

            frame = cv2.resize(frame, (1020, 600))

            # NOTE(review): the module-level model is shared by all streams and
            # persist=True keeps tracker state between calls; confirm that
            # concurrent streams do not interfere with each other's track IDs.
            results = model.track(frame, persist=True)

            if results and results[0].boxes is not None and results[0].boxes.id is not None:
                detections = results[0].boxes
                boxes = detections.xyxy.int().cpu().tolist()
                class_ids = detections.cls.int().cpu().tolist()
                track_ids = detections.id.int().cpu().tolist()

                for box, class_id, track_id in zip(boxes, class_ids, track_ids):
                    label_name = names[class_id]
                    x1, y1, x2, y2 = box
                    center = ((x1 + x2) // 2, (y1 + y2) // 2)
                    is_vehicle = label_name in VEHICLE_CLASSES

                    if is_vehicle:
                        if track_id in track_positions and track_id not in counted_ids:
                            previous = track_positions[track_id]
                            cv2.line(frame, previous, center, (255, 100, 0), 2)
                            if crossed_line(previous, center, line_start, line_end):
                                counted_ids.add(track_id)
                                vehicle_count += 1
                                vehicle_type_counts[label_name] += 1
                                cv2.circle(frame, center, 25, (0, 255, 0), 5)
                        track_positions[track_id] = center

                    box_color = (0, 255, 0) if is_vehicle else (255, 0, 0)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), box_color, 2)

                    label = f'{track_id} - {label_name}'
                    if track_id in counted_ids:
                        # ASCII marker: Hershey fonts cannot render non-ASCII
                        # glyphs and would draw '?' instead.
                        label += ' OK'
                    cv2.putText(frame, label, (x1, y1 - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 255), 1)
                    cv2.circle(frame, center, 3, (0, 255, 255), -1)

            _draw_counting_line(frame, line_start, line_end)
            _draw_counter_overlay(frame, vehicle_count, vehicle_type_counts)

            encoded, buffer = cv2.imencode('.jpg', frame)
            if not encoded:
                # Skip frames that could not be encoded instead of emitting garbage.
                continue
            frame_bytes = buffer.tobytes()

            yield (b'--frame\r\n'
                   b'Content-Type: image/jpeg\r\n\r\n' + frame_bytes + b'\r\n')
    finally:
        capture.release()
        release_reset_event(stream_id)


@app.route('/')
def index():
    """Render the landing page."""
    return render_template('index.html')


@app.route('/start_webcam')
def start_webcam():
    """Render the webcam page, ensuring the session has a line and stream id."""
    get_line_from_session()
    stream_id = get_webcam_stream_id()
    return render_template('webcam.html', stream_id=stream_id)


@app.route('/api/set_line', methods=['POST'])
def set_counting_line():
    """API endpoint to set the counting line coordinates.

    Expects a JSON object with ``x1``, ``y1``, ``x2`` and ``y2``; responds
    with 400 if any of them is missing or cannot be converted to an integer.
    """
    data = request.get_json(silent=True) or {}
    try:
        line = {
            'x1': int(data['x1']),
            'y1': int(data['y1']),
            'x2': int(data['x2']),
            'y2': int(data['y2']),
        }
    except (KeyError, ValueError, TypeError, OverflowError):
        # OverflowError covers infinite floats, which int() cannot convert.
        abort(400, description='Ungültige Linienkoordinaten')
    session['counting_line'] = line
    return jsonify({'status': 'success', 'line': session['counting_line']})


@app.route('/api/get_line', methods=['GET'])
def get_counting_line():
    """API endpoint to get the current counting line coordinates"""
    return jsonify(get_line_from_session())


@app.route('/api/reset_count', methods=['POST'])
def reset_count():
    """API endpoint to reset the vehicle count for a stream.

    Expects a JSON object with a string ``stream_id``. Responds with 400 when
    it is missing or malformed, and 403 when the id is not one of the stream
    ids stored in the current session.
    """
    data = request.get_json(silent=True)
    # A non-object JSON body (e.g. a list) has no .get(); treat it as missing.
    stream_id = data.get('stream_id') if isinstance(data, dict) else None
    # Non-string ids (e.g. lists) are unhashable and would break the set lookup.
    if not stream_id or not isinstance(stream_id, str):
        abort(400, description='stream_id ist erforderlich')

    valid_ids = {session.get('webcam_stream_id')}
    valid_ids.update(session.get('video_stream_ids', {}).values())
    valid_ids.discard(None)

    if stream_id not in valid_ids:
        abort(403, description='Stream gehört nicht zur aktuellen Sitzung')

    get_reset_event(stream_id).set()
    return jsonify({'status': 'success', 'message': 'Zähler wird zurückgesetzt'})


def detect_objects_from_webcam(line_data, stream_id):
    """Open the default webcam and return its frame generator.

    Raises:
        RuntimeError: If the webcam cannot be opened.
    """
    cap = cv2.VideoCapture(0)  # 0 for the default webcam
    if not cap.isOpened():
        cap.release()
        raise RuntimeError('Webcam konnte nicht geöffnet werden')
    return generate_frames(cap, line_data, stream_id)


@app.route('/webcam_feed')
def webcam_feed():
    """Stream annotated webcam frames; respond with 503 if the webcam is unavailable."""
    line_data = get_line_from_session()
    stream_id = get_webcam_stream_id()
    try:
        generator = detect_objects_from_webcam(line_data, stream_id)
    except RuntimeError as exc:
        abort(503, description=str(exc))
    return Response(generator, mimetype='multipart/x-mixed-replace; boundary=frame')


@app.route('/upload', methods=['POST'])
def upload_video():
    """Store an uploaded video under a unique name and redirect to its player page."""
    if 'file' not in request.files:
        abort(400, description='Keine Datei erhalten')

    file = request.files['file']
    if not file or file.filename == '':
        abort(400, description='Keine Datei ausgewählt')

    filename = secure_filename(file.filename)
    if not filename:
        abort(400, description='Ungültiger Dateiname')
    if not allowed_file(filename):
        abort(400, description='Ungültiger Dateityp')

    ensure_upload_dir()
    name, ext = os.path.splitext(filename)
    # A random suffix keeps uploads with the same name from overwriting each other.
    stored_filename = f"{name}_{uuid.uuid4().hex}{ext.lower()}"
    file_path = os.path.join(UPLOAD_DIR, stored_filename)
    file.save(file_path)
    return redirect(url_for('play_video', filename=stored_filename))


def _resolve_upload(filename: str) -> tuple[str, str]:
    """Validate *filename* and return ``(safe_filename, path)`` inside UPLOAD_DIR.

    Aborts with 400 if the name contains path components and with 404 if no
    such file exists in the upload directory.
    """
    safe_filename = os.path.basename(filename)
    if safe_filename != filename:
        abort(400, description='Ungültiger Dateiname')
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    if not os.path.isfile(file_path):
        abort(404)
    return safe_filename, file_path


@app.route('/uploads/<filename>')
def play_video(filename):
    """Render the player page for an uploaded video."""
    safe_filename, _ = _resolve_upload(filename)
    get_line_from_session()
    stream_id = get_video_stream_id(safe_filename)
    return render_template('play_video.html', filename=safe_filename, stream_id=stream_id)


@app.route('/video/<filename>')
def send_video(filename):
    """Serve the raw uploaded video file from the upload directory."""
    return send_from_directory(UPLOAD_DIR, filename)


def detect_objects_from_video(video_path, line_data, stream_id):
    """Open the video at *video_path* and return its frame generator.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise RuntimeError('Video konnte nicht geöffnet werden')
    return generate_frames(cap, line_data, stream_id)


@app.route('/video_feed/<filename>')
def video_feed(filename):
    """Stream annotated frames of an uploaded video; 503 if it cannot be opened."""
    safe_filename, video_path = _resolve_upload(filename)
    line_data = get_line_from_session()
    stream_id = get_video_stream_id(safe_filename)
    try:
        generator = detect_objects_from_video(video_path, line_data, stream_id)
    except RuntimeError as exc:
        abort(503, description=str(exc))
    return Response(generator, mimetype='multipart/x-mixed-replace; boundary=frame')


if __name__ == '__main__':
    app.run('0.0.0.0', debug=False, port=8080)