Skip to content

fine_tune

training.fine_tune

OrionDataset

Bases: Dataset

Loads the raw JSON records without any preprocessing; all image and text processing happens in `VLMDataCollator`.

Source code in ground_segment/training/fine_tune.py
20
21
22
23
24
25
26
27
28
29
30
class OrionDataset(Dataset):
    """Expose the "train" split of a JSON file as an indexable dataset.

    Records are handed back exactly as loaded; no image or text processing
    is performed here (that is left to the collator).
    """

    def __init__(self, jsonl_file):
        """Load ``jsonl_file`` with the "json" builder and keep its "train" split."""
        # The loader's result is indexed by split name; only "train" is kept.
        splits = load_dataset("json", data_files={"train": jsonl_file})
        self.data = splits["train"]

    def __len__(self):
        """Return the number of loaded records."""
        records = self.data
        return len(records)

    def __getitem__(self, idx):
        """Return the raw record stored at position ``idx``."""
        record = self.data[idx]
        return record

VLMDataCollator

Processes an entire batch of images and texts in a single processor call, producing padded model inputs and labels for the VLM.

Source code in ground_segment/training/fine_tune.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class VLMDataCollator:
    """Turn a batch of raw records into padded processor outputs with labels.

    Each record is expected to provide an ``"image"`` entry (a path relative
    to ``image_root``) and a ``"conversations"`` list whose first two items
    carry the user prompt and the assistant response under ``"content"``.

    Args:
        processor: Callable multimodal processor that also exposes
            ``apply_chat_template``.
        image_root: Directory prefix prepended to every record's ``"image"``
            value. Defaults to the location that used to be hard-coded, so
            existing callers behave exactly as before.
    """

    DEFAULT_IMAGE_ROOT = "/home/schopra/hdd/gaze/datasets/extras/"

    def __init__(self, processor, image_root=DEFAULT_IMAGE_ROOT):
        self.processor = processor
        # Normalise to exactly one trailing slash so the prefix can be
        # concatenated directly with the relative image path.
        self.image_root = str(image_root).rstrip("/") + "/"

    def __call__(self, batch):
        """Collate ``batch`` into the processor's output plus a ``"labels"`` entry.

        Returns:
            The object returned by the processor, with ``"labels"`` added:
            a copy of ``"input_ids"`` in which padded positions (those where
            ``"attention_mask"`` is 0) are replaced by -100.
        """
        images = []
        texts = []

        for item in batch:
            # Close the file handle deterministically; convert() returns an
            # independent, fully loaded RGB copy that outlives the handle.
            with Image.open(f"{self.image_root}{item['image']}") as raw_image:
                images.append([raw_image.convert("RGB")])

            prompt = item["conversations"][0]["content"]
            response = item["conversations"][1]["content"]
            messages = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": response},
            ]
            texts.append(self.processor.apply_chat_template(messages, tokenize=False))

        # One processor call for the whole batch. Dynamic padding only; no
        # truncation, so the image tokens are never cut off.
        inputs = self.processor(
            images=images,
            text=texts,
            return_tensors="pt",
            padding=True,
        )

        # Labels mirror input_ids, except that padding must not contribute to
        # the loss. -100 is the default ignore_index of PyTorch's
        # cross-entropy loss. Masking via attention_mask (rather than
        # pad_token_id) stays correct even when the pad token doubles as EOS.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            inputs["labels"] = inputs["input_ids"].clone()
        else:
            # masked_fill returns a new tensor; input_ids is left untouched.
            inputs["labels"] = inputs["input_ids"].masked_fill(
                attention_mask == 0, -100
            )
        return inputs