Skip to content

Commit

Permalink
ogr2ogr: optim: call GetArrowStream() only once on source layer when …
Browse files Browse the repository at this point in the history
…using Arrow interface

as this might be an expensive operation on some drivers.

Covered by existing tests.
  • Loading branch information
rouault committed Oct 13, 2024
1 parent fb97969 commit 9d058a9
Showing 1 changed file with 60 additions and 50 deletions.
110 changes: 60 additions & 50 deletions apps/ogr2ogr_lib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ struct TargetLayerInfo
const char *m_pszGeomField = nullptr;
std::vector<int> m_anDateTimeFieldIdx{};
bool m_bSupportCurves = false;
OGRArrowArrayStream m_sArrowArrayStream{};
};

struct AssociatedLayers
Expand All @@ -507,7 +508,8 @@ class SetupTargetLayer
bool CanUseWriteArrowBatch(OGRLayer *poSrcLayer, OGRLayer *poDstLayer,
bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions,
bool &bError);
bool bPreserveFID, bool &bError,
OGRArrowArrayStream &streamSrc);

public:
GDALDataset *m_poSrcDS = nullptr;
Expand Down Expand Up @@ -3990,13 +3992,46 @@ static int GetArrowGeomFieldIndex(const struct ArrowSchema *psLayerSchema,
return -1;
}

/************************************************************************/
/* BuildGetArrowStreamOptions() */
/************************************************************************/

/** Assemble the option list handed to OGRLayer::GetArrowStream().
 *
 * Always sets SILENCE_GET_SCHEMA_ERROR=YES and GEOMETRY_ENCODING=WKB.
 * Adds INCLUDE_FID=NO when the FID is not to be preserved, and sets
 * MAX_FEATURES_IN_BATCH from the feature limit and/or the transaction
 * group size when either is specified.
 *
 * @param psOptions    Translation options; nLimit and nGroupTransactions
 *                     are read.
 * @param bPreserveFID Whether the source FID must be carried over.
 * @return the option list, by value.
 */
static CPLStringList
BuildGetArrowStreamOptions(const GDALVectorTranslateOptions *psOptions,
                           bool bPreserveFID)
{
    CPLStringList aosStreamOpts;
    aosStreamOpts.SetNameValue("SILENCE_GET_SCHEMA_ERROR", "YES");
    aosStreamOpts.SetNameValue("GEOMETRY_ENCODING", "WKB");
    if (!bPreserveFID)
    {
        aosStreamOpts.SetNameValue("INCLUDE_FID", "NO");
    }

    const bool bHasGroupSize = psOptions->nGroupTransactions > 0;
    if (psOptions->nLimit >= 0)
    {
        // A batch never needs to exceed the requested feature limit; it is
        // further capped by the transaction group size, or by 65536 when no
        // positive group size is set.
        const auto nBatchCeiling =
            bHasGroupSize ? psOptions->nGroupTransactions : 65536;
        const auto nMaxFeaturesInBatch =
            std::min<GIntBig>(psOptions->nLimit, nBatchCeiling);
        aosStreamOpts.SetNameValue(
            "MAX_FEATURES_IN_BATCH",
            CPLSPrintf(CPL_FRMT_GIB, nMaxFeaturesInBatch));
    }
    else if (bHasGroupSize)
    {
        // No feature limit: size batches on the transaction group size only.
        aosStreamOpts.SetNameValue(
            "MAX_FEATURES_IN_BATCH",
            CPLSPrintf("%d", psOptions->nGroupTransactions));
    }

    return aosStreamOpts;
}

/************************************************************************/
/* SetupTargetLayer::CanUseWriteArrowBatch() */
/************************************************************************/

bool SetupTargetLayer::CanUseWriteArrowBatch(
OGRLayer *poSrcLayer, OGRLayer *poDstLayer, bool bJustCreatedLayer,
const GDALVectorTranslateOptions *psOptions, bool &bError)
const GDALVectorTranslateOptions *psOptions, bool bPreserveFID,
bool &bError, OGRArrowArrayStream &streamSrc)
{
bError = false;

Expand Down Expand Up @@ -4050,20 +4085,20 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
}

struct ArrowArrayStream streamSrc;
const char *const apszOptions[] = {"SILENCE_GET_SCHEMA_ERROR=YES",
nullptr};
if (poSrcLayer->GetArrowStream(&streamSrc, apszOptions))
const CPLStringList aosGetArrowStreamOptions(
BuildGetArrowStreamOptions(psOptions, bPreserveFID));
if (poSrcLayer->GetArrowStream(streamSrc.get(),
aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaSrc;
if (streamSrc.get_schema(&streamSrc, &schemaSrc) == 0)
if (streamSrc.get_schema(&schemaSrc) == 0)
{
if (psOptions->bTransform &&
GetArrowGeomFieldIndex(&schemaSrc,
poSrcLayer->GetGeometryColumn()) < 0)
{
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}

Expand Down Expand Up @@ -4145,7 +4180,7 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
"Cannot create field %s",
pszFieldName);
schemaSrc.release(&schemaSrc);
streamSrc.release(&streamSrc);
streamSrc.clear();
return false;
}
}
Expand All @@ -4157,7 +4192,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
// check that it looks to be the same as the source
// one
struct ArrowArrayStream streamDst;
if (poDstLayer->GetArrowStream(&streamDst, nullptr))
if (poDstLayer->GetArrowStream(
&streamDst, aosGetArrowStreamOptions.List()))
{
struct ArrowSchema schemaDst;
if (streamDst.get_schema(&streamDst, &schemaDst) ==
Expand Down Expand Up @@ -4188,7 +4224,8 @@ bool SetupTargetLayer::CanUseWriteArrowBatch(
}
schemaSrc.release(&schemaSrc);
}
streamSrc.release(&streamSrc);
if (!bUseWriteArrowBatch)
streamSrc.clear();
}
}
return bUseWriteArrowBatch;
Expand Down Expand Up @@ -4915,8 +4952,10 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
}

bool bError = false;
const bool bUseWriteArrowBatch = CanUseWriteArrowBatch(
poSrcLayer, poDstLayer, bJustCreatedLayer, psOptions, bError);
OGRArrowArrayStream streamSrc;
const bool bUseWriteArrowBatch =
CanUseWriteArrowBatch(poSrcLayer, poDstLayer, bJustCreatedLayer,
psOptions, bPreserveFID, bError, streamSrc);
if (bError)
return nullptr;

Expand Down Expand Up @@ -5378,7 +5417,7 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
nTotalEventsDone = 0;
}

std::unique_ptr<TargetLayerInfo> psInfo(new TargetLayerInfo);
auto psInfo = std::make_unique<TargetLayerInfo>();
psInfo->m_bUseWriteArrowBatch = bUseWriteArrowBatch;
psInfo->m_nFeaturesRead = 0;
psInfo->m_bPerFeatureCT = false;
Expand Down Expand Up @@ -5475,6 +5514,8 @@ SetupTargetLayer::Setup(OGRLayer *poSrcLayer, const char *pszNewLayerName,
psInfo->m_bSupportCurves =
CPL_TO_BOOL(poDstLayer->TestCapability(OLCCurveGeometries));

psInfo->m_sArrowArrayStream = std::move(streamSrc);

return psInfo;
}

Expand Down Expand Up @@ -5769,49 +5810,19 @@ bool LayerTranslator::TranslateArrow(
GIntBig *pnReadFeatureCount, GDALProgressFunc pfnProgress,
void *pProgressArg, const GDALVectorTranslateOptions *psOptions)
{
struct ArrowArrayStream stream;
struct ArrowSchema schema;
CPLStringList aosOptionsGetArrowStream;
CPLStringList aosOptionsWriteArrowBatch;
aosOptionsGetArrowStream.SetNameValue("GEOMETRY_ENCODING", "WKB");
if (!psInfo->m_bPreserveFID)
aosOptionsGetArrowStream.SetNameValue("INCLUDE_FID", "NO");
else
if (psInfo->m_bPreserveFID)
{
aosOptionsWriteArrowBatch.SetNameValue(
"FID", psInfo->m_poSrcLayer->GetFIDColumn());
aosOptionsWriteArrowBatch.SetNameValue("IF_FID_NOT_PRESERVED",
"WARNING");
}
if (psOptions->nLimit >= 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf(CPL_FRMT_GIB,
std::min<GIntBig>(psOptions->nLimit,
(psOptions->nGroupTransactions > 0
? psOptions->nGroupTransactions
: 65536))));
}
else if (psOptions->nGroupTransactions > 0)
{
aosOptionsGetArrowStream.SetNameValue(
"MAX_FEATURES_IN_BATCH",
CPLSPrintf("%d", psOptions->nGroupTransactions));
}
if (psInfo->m_poSrcLayer->GetArrowStream(&stream,
aosOptionsGetArrowStream.List()))
{
if (stream.get_schema(&stream, &schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
stream.release(&stream);
return false;
}
}
else

if (psInfo->m_sArrowArrayStream.get_schema(&schema) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "GetArrowStream() failed");
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_schema() failed");
return false;
}

Expand Down Expand Up @@ -5865,7 +5876,7 @@ bool LayerTranslator::TranslateArrow(
{
struct ArrowArray array;
// Acquire source batch
if (stream.get_next(&stream, &array) != 0)
if (psInfo->m_sArrowArrayStream.get_next(&array) != 0)
{
CPLError(CE_Failure, CPLE_AppDefined, "stream.get_next() failed");
bRet = false;
Expand Down Expand Up @@ -6043,7 +6054,6 @@ bool LayerTranslator::TranslateArrow(

schema.release(&schema);

stream.release(&stream);
return bRet;
}

Expand Down

0 comments on commit 9d058a9

Please sign in to comment.