Skip to content

validation

training.validation

extract_json(text)

Strictly extracts JSON, failing over to ERROR if the model disobeys formatting.

Source code in ground_segment/training/validation.py
17
18
19
20
21
22
23
24
25
26
27
def extract_json(text):
    """Strictly extracts JSON, failing over to ERROR if the model disobeys formatting."""
    # Locate the outermost brace pair; any chatter outside it is ignored.
    opening = text.find("{")
    closing = text.rfind("}") + 1
    if opening != -1 and closing != 0:
        try:
            return json.loads(text[opening:closing])
        except json.JSONDecodeError:
            # Fall through to the ERROR payload below.
            pass

    # No parseable JSON: surface a truncated echo of the raw model output.
    return {"category": "ERROR", "reason": f"Raw: {text.strip()[:50]}"}

print_confusion_matrix(truths, preds, condition_name)

Prints a per-class Recall and Precision breakdown, plus an aggregate total.

Source code in ground_segment/training/validation.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def print_confusion_matrix(truths, preds, condition_name):
    """Prints a per-class Recall and Precision breakdown, plus an aggregate total."""
    print(f"\n--- {condition_name} ---")
    n_samples = len(truths)
    n_correct = 0

    for label in ["HIGH", "MEDIUM", "LOW"]:
        actual_count = truths.count(label)
        # Classes absent from the ground truth are skipped entirely.
        if not actual_count:
            continue

        predicted_count = preds.count(label)
        hits = sum(1 for truth, pred in zip(truths, preds) if truth == label == pred)
        n_correct += hits

        recall = hits / actual_count * 100
        # Guard against division by zero when the class was never predicted.
        precision = hits / predicted_count * 100 if predicted_count else 0.0

        print(
            f"{label:6s}: {hits:2d}/{actual_count:2d} ({recall:.1f}% Recall) | Precision: {hits:2d}/{predicted_count:<2d} ({precision:.1f}%)"
        )

    if n_samples:
        print(
            f"TOTAL : {n_correct:2d}/{n_samples:2d} ({n_correct / n_samples * 100:.1f}% Overall Accuracy)"
        )

run_inference(model, processor, image, prompt)

Runs a single image+prompt through the VLM and returns parsed JSON + raw text.

Source code in ground_segment/training/validation.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def run_inference(model, processor, image, prompt, max_new_tokens=200):
    """Runs a single image+prompt through the VLM and returns parsed JSON + raw text.

    Args:
        model: Vision-language model exposing `.parameters()` and `.generate()`.
        processor: Companion processor providing `apply_chat_template`,
            tokenization, and `decode` (presumably a HF processor — confirm).
        image: Input image in whatever format the processor accepts
            (assumes PIL-like — TODO confirm against callers).
        prompt: Instruction text placed after the `<image>` token.
        max_new_tokens: Generation budget; defaults to 200 for backward
            compatibility with the previous hard-coded value.

    Returns:
        tuple: (parsed dict from `extract_json`, raw decoded string for debugging).
    """
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    text_input = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Use the device of the model (handles both CUDA and MPS seamlessly)
    device = next(model.parameters()).device
    inputs = processor(images=[image], text=[text_input], return_tensors="pt").to(
        device, torch.float16
    )

    # Greedy decoding (do_sample=False) keeps validation runs deterministic;
    # no_grad avoids building an autograd graph during pure inference.
    with torch.no_grad():
        output = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )

    # Slice off the prompt tokens so only newly generated text is decoded.
    generated_text = processor.decode(
        output[0][inputs["input_ids"].shape[1] :], skip_special_tokens=True
    )

    # Return both the parsed dictionary and the raw string for debugging
    return extract_json(generated_text), generated_text