Skip to content

Commit

Permalink
consistently pass around ref_prefix and protocol_version in dulwich.c…
Browse files Browse the repository at this point in the history
…lient (jelmer#1421)
  • Loading branch information
jelmer authored Nov 5, 2024
2 parents e1c813c + ba2530e commit a90b9d7
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 26 deletions.
119 changes: 94 additions & 25 deletions dulwich/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@
extract_capability_names,
parse_capability,
pkt_line,
pkt_seq,
)
from .refs import (
PEELED_TAG_SUFFIX,
Expand All @@ -128,6 +129,12 @@
)
from .repo import Repo

# Default ref prefix, used if none is specified.
# GitHub defaults to just sending HEAD if no ref-prefix is
# specified, so explicitly request all refs to match
# behaviour with v1 when no ref-prefix is specified.
DEFAULT_REF_PREFIX = [b"HEAD", b"refs/"]

ObjectID = bytes


Expand Down Expand Up @@ -1037,7 +1044,12 @@ def fetch_pack(
"""
raise NotImplementedError(self.fetch_pack)

def get_refs(self, path):
def get_refs(
self,
path,
protocol_version: Optional[int] = None,
ref_prefix: Optional[list[Ref]] = None,
):
"""Retrieve the current refs from a git smart server.
Args:
Expand Down Expand Up @@ -1187,7 +1199,12 @@ def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs) -> None:
self._remote_path_encoding = path_encoding
super().__init__(**kwargs)

async def _connect(self, cmd, path, protocol_version=None):
def _connect(
self,
cmd: bytes,
path: Union[str, bytes],
protocol_version: Optional[int] = None,
) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
"""Create a connection to the server.
This method is abstract - concrete implementations should
Expand Down Expand Up @@ -1375,10 +1392,7 @@ def fetch_pack(
proto.write_pkt_line(b"symrefs")
proto.write_pkt_line(b"peel")
if ref_prefix is None:
# GitHub defaults to just sending HEAD if no ref-prefix is
# specified, so explicitly request all refs to match
# behaviour with v1 when no ref-prefix is specified.
ref_prefix = [b"HEAD", b"refs/"]
ref_prefix = DEFAULT_REF_PREFIX
for prefix in ref_prefix:
proto.write_pkt_line(b"ref-prefix " + prefix)
proto.write_pkt_line(None)
Expand Down Expand Up @@ -1434,7 +1448,12 @@ def fetch_pack(
)
return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow)

def get_refs(self, path, protocol_version=None):
def get_refs(
self,
path,
protocol_version: Optional[int] = None,
ref_prefix: Optional[list[Ref]] = None,
):
"""Retrieve the current refs from a git smart server."""
# stock `git ls-remote` uses upload-pack
if (
Expand All @@ -1460,6 +1479,10 @@ def get_refs(self, path, protocol_version=None):
proto.write(b"0001") # delim-pkt
proto.write_pkt_line(b"symrefs")
proto.write_pkt_line(b"peel")
if ref_prefix is None:
ref_prefix = DEFAULT_REF_PREFIX
for prefix in ref_prefix:
proto.write_pkt_line(b"ref-prefix " + prefix)
proto.write_pkt_line(None)
with proto:
try:
Expand Down Expand Up @@ -1548,7 +1571,12 @@ def get_url(self, path):
netloc += ":%d" % self._port
return urlunsplit(("git", netloc, path, "", ""))

def _connect(self, cmd, path, protocol_version=None):
def _connect(
self,
cmd: bytes,
path: Union[str, bytes],
protocol_version: Optional[int] = None,
) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
if not isinstance(cmd, bytes):
raise TypeError(cmd)
if not isinstance(path, bytes):
Expand All @@ -1558,8 +1586,8 @@ def _connect(self, cmd, path, protocol_version=None):
)
s = None
err = OSError(f"no address found for {self._host}")
for family, socktype, proto, canonname, sockaddr in sockaddrs:
s = socket.socket(family, socktype, proto)
for family, socktype, protof, canonname, sockaddr in sockaddrs:
s = socket.socket(family, socktype, protof)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
try:
s.connect(sockaddr)
Expand Down Expand Up @@ -1668,7 +1696,12 @@ def from_parsedurl(cls, parsedurl, **kwargs):

git_command = None

def _connect(self, service, path, protocol_version=None):
def _connect(
self,
service: bytes,
path: Union[bytes, str],
protocol_version: Optional[int] = None,
) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
if not isinstance(service, bytes):
raise TypeError(service)
if isinstance(path, bytes):
Expand Down Expand Up @@ -1890,7 +1923,12 @@ def fetch_pack(
)
return FetchPackResult(r.get_refs(), symrefs, agent)

def get_refs(self, path):
def get_refs(
self,
path,
protocol_version: Optional[int] = None,
ref_prefix: Optional[list[Ref]] = None,
):
"""Retrieve the current refs from a local on-disk repository."""
with self._open_repo(path) as target:
return target.get_refs()
Expand Down Expand Up @@ -1952,7 +1990,7 @@ def run_command(
password=None,
key_filename=None,
ssh_command=None,
protocol_version=None,
protocol_version: Optional[int] = None,
):
if password is not None:
raise NotImplementedError(
Expand Down Expand Up @@ -2127,7 +2165,12 @@ def _get_cmd_path(self, cmd):
assert isinstance(cmd, bytes)
return cmd

def _connect(self, cmd, path, protocol_version=None):
def _connect(
self,
cmd: bytes,
path: Union[str, bytes],
protocol_version: Optional[int] = None,
) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
if not isinstance(cmd, bytes):
raise TypeError(cmd)
if isinstance(path, bytes):
Expand Down Expand Up @@ -2361,7 +2404,11 @@ def _http_request(self, url, headers=None, data=None):
raise NotImplementedError(self._http_request)

def _discover_references(
self, service, base_url, protocol_version=None
self,
service,
base_url,
protocol_version: Optional[int] = None,
ref_prefix: Optional[list[Ref]] = None,
) -> tuple[
dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
]:
Expand Down Expand Up @@ -2407,15 +2454,24 @@ def _discover_references(
if not self.dumb:

def begin_protocol_v2(proto):
nonlocal ref_prefix
server_capabilities = read_server_capabilities(proto.read_pkt_seq())
if ref_prefix is None:
ref_prefix = DEFAULT_REF_PREFIX

pkts = [
b"symrefs",
b"peel",
]
for prefix in ref_prefix:
pkts.append(b"ref-prefix " + prefix)

body = b"".join(
[pkt_line(b"command=ls-refs\n"), b"0001", pkt_seq(*pkts)]
)

resp, read = self._smart_request(
service.decode("ascii"),
base_url,
pkt_line(b"command=ls-refs\n")
+ b"0001"
+ pkt_line(b"symrefs")
+ pkt_line(b"peel")
+ b"0000",
service.decode("ascii"), base_url, body
)
proto = Protocol(read, None)
return server_capabilities, resp, read, proto
Expand Down Expand Up @@ -2613,7 +2669,10 @@ def fetch_pack(
"""
url = self._get_url(path)
refs, server_capabilities, url, symrefs, peeled = self._discover_references(
b"git-upload-pack", url, protocol_version
b"git-upload-pack",
url,
protocol_version=protocol_version,
ref_prefix=ref_prefix,
)
(
negotiated_capabilities,
Expand Down Expand Up @@ -2678,10 +2737,20 @@ def fetch_pack(
finally:
resp.close()

def get_refs(self, path):
def get_refs(
self,
path,
protocol_version: Optional[int] = None,
ref_prefix: Optional[list[Ref]] = None,
):
"""Retrieve the current refs from a git smart server."""
url = self._get_url(path)
refs, _, _, _, peeled = self._discover_references(b"git-upload-pack", url)
refs, _, _, _, peeled = self._discover_references(
b"git-upload-pack",
url,
protocol_version=protocol_version,
ref_prefix=ref_prefix,
)
for refname, refvalue in peeled.items():
refs[refname + PEELED_TAG_SUFFIX] = refvalue
return refs
Expand Down
1 change: 1 addition & 0 deletions dulwich/porcelain.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def clone(
depth=depth,
filter_spec=filter_spec,
protocol_version=protocol_version,
**kwargs,
)


Expand Down
9 changes: 9 additions & 0 deletions dulwich/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,15 @@ def pkt_line(data):
return ("%04x" % (len(data) + 4)).encode("ascii") + data


def pkt_seq(*seq):
"""Wrap a sequence of data in pkt-lines.
Args:
seq: An iterable of strings to wrap.
"""
return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)


class Protocol:
"""Class for interacting with a remote git process over the wire.
Expand Down
2 changes: 1 addition & 1 deletion dulwich/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1225,7 +1225,7 @@ def _write_reflog(
pass
if committer is None:
config = self.get_config_stack()
committer = self._get_user_identity(config)
committer = get_user_identity(config)
check_user_identity(committer)
if timestamp is None:
timestamp = int(time.time())
Expand Down
12 changes: 12 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,23 @@
ack_type,
extract_capabilities,
extract_want_line_capabilities,
pkt_line,
pkt_seq,
)

from . import TestCase


class PktLinetests:
def test_pkt_line(self):
self.assertEqual(b"0007bla", pkt_line(b"bla"))
self.assertEqual(b"0000", pkt_line(None))

def test_pkt_seq(self):
self.assertEqual(b"0007bla0003foo0000", pkt_seq([b"bla", b"foo"]))
self.assertEqual(b"0000", pkt_seq([]))


class BaseProtocolTests:
def test_write_pkt_line_none(self):
self.proto.write_pkt_line(None)
Expand Down

0 comments on commit a90b9d7

Please sign in to comment.