File size: 5,675 Bytes
5e9fb2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright 2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import shutil
import subprocess
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Union

from . import constants
from .utils import get_token


if TYPE_CHECKING:
    import duckdb


@dataclass(frozen=True)
class DatasetParquetEntry:
    """Immutable description of one parquet file published on the Hub for a dataset.

    Attributes:
        config: Name of the dataset configuration the file belongs to.
        split: Name of the split the file belongs to.
        url: Location of the parquet file.
        size: Size of the file (presumably in bytes — confirm against the producer of these entries).
    """

    config: str
    split: str
    url: str
    size: int


def execute_raw_sql_query(sql_query: str, *, token: str | bool | None = None) -> list[dict[str, Any]]:
    """Run a raw SQL query with DuckDB and return every result row as a dictionary.

    The query is stripped of surrounding whitespace and trailing semicolons, then validated. It runs
    through the `duckdb` Python package when that package is importable. Otherwise it runs through the
    `duckdb` CLI binary.

    Args:
        sql_query: SQL text to execute.
        token: Hub token used to register DuckDB secrets. `None` or `True` resolves the token through
            `get_token()`. `False` (or an empty value) registers no secret.

    Returns:
        One dict per result row, mapping column name to value.

    Raises:
        ValueError: If the query is empty, contains a CLI meta-command, or does not return rows.
        ImportError: If neither the DuckDB Python package nor the DuckDB CLI binary is available.
        RuntimeError: If the DuckDB CLI exits with an error.
    """
    query = sql_query.strip().rstrip(";").strip()
    _raise_on_forbidden_query(query)

    connection = None
    try:
        connection = _get_duckdb_connection(token=token)
        relation = connection.sql(query)
        if relation is None:
            # The Python API returns no relation for statements that produce no result set.
            raise ValueError("SQL query must return rows.")

        if not isinstance(relation, _DuckDBCliRelation):
            # DuckDB Python API: read column names from the cursor description, then pair them with each row.
            column_names = [descriptor[0] for descriptor in relation.description]
            return [dict(zip(column_names, record)) for record in relation.fetchall()]

        # DuckDB CLI fallback: the relation runs the binary and parses its JSON output itself.
        return relation.execute()
    finally:
        if connection is not None:
            connection.close()


def _raise_on_forbidden_query(query: str) -> None:
    if len(query) == 0:
        raise ValueError("SQL query cannot be empty.")

    # DuckDB CLI meta-commands are dot-prefixed words (e.g. `.shell`, `.output`).
    # Let's forbid them for now but allow SQL expressions like `.5` that can legitimately start a line.
    for line in query.splitlines():
        stripped = line.lstrip()
        if stripped.startswith(".") and stripped[1:2].isalpha():
            raise ValueError("DuckDB CLI meta-commands are not allowed in SQL queries.")


def _get_duckdb_connection(
    token: str | bool | None,
) -> Union["duckdb.DuckDBPyConnection", "_DuckDBCliConnection"]:
    """Open a DuckDB connection with Hub credentials registered as secrets.

    The `duckdb` Python package is used when it can be imported. Otherwise the `duckdb` binary found on
    PATH is wrapped in a `_DuckDBCliConnection`, which mimics the parts of the Python API used here.

    Args:
        token: Hub token forwarded to `_build_duckdb_secret_statements`.

    Returns:
        A DuckDB Python connection, or a CLI-backed stand-in.

    Raises:
        ImportError: If neither the Python package nor the CLI binary is available.
    """
    try:
        import duckdb
    except ImportError as error:
        cli_path = shutil.which("duckdb")
        if cli_path is not None:
            return _DuckDBCliConnection(binary_path=cli_path, token=token)
        raise ImportError(
            "DuckDB is required for `hf datasets sql`. Install the Python package with `pip install duckdb` or "
            "install the DuckDB CLI binary (for example `brew install duckdb`)."
        ) from error

    connection = duckdb.connect()
    try:
        for statement in _build_duckdb_secret_statements(token):
            connection.execute(statement)
    except Exception:
        # Do not leak the freshly opened connection if secret registration fails.
        connection.close()
        raise
    return connection


@dataclass
class _DuckDBCliConnection:
    """DuckDB connection.

    Mimics the DuckDB Python API, but runs the queries via the DuckDB CLI binary.
    """

    binary_path: str
    token: str | bool | None

    def __post_init__(self) -> None:
        self._setup_statements = _build_duckdb_secret_statements(self.token)

    def sql(self, query: str) -> "_DuckDBCliRelation":
        return _DuckDBCliRelation(binary_path=self.binary_path, setup_statements=self._setup_statements, query=query)

    def close(self) -> None:
        pass


@dataclass
class _DuckDBCliRelation:
    """DuckDB relation.

    Mimics the DuckDB Python API, but runs the queries via the DuckDB CLI binary.
    """

    binary_path: str
    setup_statements: list[str]
    query: str

    def execute(self) -> list[dict[str, Any]]:
        # Build the DuckDB CLI input.
        setup = []
        if self.setup_statements:
            setup = [
                f".output {os.devnull}",
                *(f"{stmt};" for stmt in self.setup_statements),
                ".output",
            ]
        full_query = "\n".join(setup + [self.query + ";"])

        # Run DuckDB binary
        result = subprocess.run(
            [self.binary_path, "-json"],
            input=full_query,
            capture_output=True,
            text=True,
            check=False,
        )
        if result.returncode != 0:
            error_message = result.stderr.strip() or result.stdout.strip() or "DuckDB CLI command failed."
            raise RuntimeError(error_message)

        # Parse JSON output and return
        return json.loads(result.stdout.strip())


def _build_duckdb_secret_statements(token: str | bool | None) -> list[str]:
    if token is None or token is True:
        token = get_token()

    if not token:
        return []

    escaped_token = token.replace("'", "''")
    escaped_endpoint = constants.ENDPOINT.replace("'", "''")
    return [
        f"CREATE OR REPLACE SECRET hf_hub_token (TYPE HTTP, BEARER_TOKEN '{escaped_token}', SCOPE '{escaped_endpoint}')",
        f"CREATE OR REPLACE SECRET hf_token (TYPE HUGGINGFACE, TOKEN '{escaped_token}')",
    ]