Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core] Improve getReactElementRef() utils #43022

Merged
merged 14 commits into from
Sep 19, 2024
4 changes: 2 additions & 2 deletions packages/mui-base/src/ClickAwayListener/ClickAwayListener.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
unstable_ownerDocument as ownerDocument,
unstable_useForkRef as useForkRef,
unstable_useEventCallback as useEventCallback,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';

// TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase<EventName> : never` once generatePropTypes runs with TS 4.1
Expand Down Expand Up @@ -95,7 +95,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): React.JSX.Element {
};
}, []);

const handleRef = useForkRef(getReactNodeRef(children), nodeRef);
const handleRef = useForkRef(getReactElementRef(children), nodeRef);

// The handler doesn't take event.defaultPrevented into account:
//
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-base/src/FocusTrap/FocusTrap.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
elementAcceptingRef,
unstable_useForkRef as useForkRef,
unstable_ownerDocument as ownerDocument,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { FocusTrapProps } from './FocusTrap.types';

Expand Down Expand Up @@ -153,7 +153,7 @@ function FocusTrap(props: FocusTrapProps): React.JSX.Element {
const activated = React.useRef(false);

const rootRef = React.useRef<HTMLElement>(null);
const handleRef = useForkRef(getReactNodeRef(children), rootRef);
const handleRef = useForkRef(getReactElementRef(children), rootRef);
const lastKeydown = React.useRef<KeyboardEvent | null>(null);

React.useEffect(() => {
Expand Down
8 changes: 6 additions & 2 deletions packages/mui-base/src/Portal/Portal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import * as React from 'react';
import * as ReactDOM from 'react-dom';
import PropTypes from 'prop-types';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import {
exactProp,
HTMLElementType,
Expand Down Expand Up @@ -34,7 +34,11 @@ const Portal = React.forwardRef(function Portal(
) {
const { children, container, disablePortal = false } = props;
const [mountNode, setMountNode] = React.useState<ReturnType<typeof getContainer>>(null);
const handleRef = useForkRef(getReactNodeRef(children), forwardedRef);

const handleRef = useForkRef(
React.isValidElement(children) ? getReactElementRef(children) : null,
forwardedRef,
);

useEnhancedEffect(() => {
if (!disablePortal) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-joy/src/Tooltip/Tooltip.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {
unstable_useId as useId,
unstable_useTimeout as useTimeout,
unstable_Timeout as Timeout,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { Popper, unstable_composeClasses as composeClasses } from '@mui/base';
import { OverridableComponent } from '@mui/types';
Expand Down Expand Up @@ -416,7 +416,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) {
}, [handleClose, open]);

const handleUseRef = useForkRef(setChildNode, ref);
const handleRef = useForkRef(getReactNodeRef(children), handleUseRef);
const handleRef = useForkRef(getReactElementRef(children), handleUseRef);

// There is no point in displaying an empty tooltip.
if (typeof title !== 'number' && !title) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
unstable_useForkRef as useForkRef,
unstable_useEventCallback as useEventCallback,
} from '@mui/utils';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';

// TODO: return `EventHandlerName extends `on${infer EventName}` ? Lowercase<EventName> : never` once generatePropTypes runs with TS 4.1
function mapEventPropToEvent(
Expand Down Expand Up @@ -96,7 +96,7 @@ function ClickAwayListener(props: ClickAwayListenerProps): React.JSX.Element {
};
}, []);

const handleRef = useForkRef(getReactNodeRef(children), nodeRef);
const handleRef = useForkRef(getReactElementRef(children), nodeRef);

// The handler doesn't take event.defaultPrevented into account:
//
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Fade/Fade.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import { Transition } from 'react-transition-group';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { useTheme } from '../zero-styled';
import { reflow, getTransitionProps } from '../transitions/utils';
import useForkRef from '../utils/useForkRef';
Expand Down Expand Up @@ -49,7 +49,7 @@ const Fade = React.forwardRef(function Fade(props, ref) {

const enableStrictModeCompat = true;
const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Grow/Grow.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import useTimeout from '@mui/utils/useTimeout';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { Transition } from 'react-transition-group';
import { useTheme } from '../zero-styled';
import { getTransitionProps, reflow } from '../transitions/utils';
Expand Down Expand Up @@ -62,7 +62,7 @@ const Grow = React.forwardRef(function Grow(props, ref) {
const theme = useTheme();

const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
8 changes: 6 additions & 2 deletions packages/mui-material/src/Portal/Portal.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
unstable_useEnhancedEffect as useEnhancedEffect,
unstable_useForkRef as useForkRef,
unstable_setRef as setRef,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { PortalProps } from './Portal.types';

Expand All @@ -34,7 +34,11 @@ const Portal = React.forwardRef(function Portal(
) {
const { children, container, disablePortal = false } = props;
const [mountNode, setMountNode] = React.useState<ReturnType<typeof getContainer>>(null);
const handleRef = useForkRef(getReactNodeRef(children), forwardedRef);

const handleRef = useForkRef(
React.isValidElement(children) ? getReactElementRef(children) : null,
forwardedRef,
);

useEnhancedEffect(() => {
if (!disablePortal) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Select/Select.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import clsx from 'clsx';
import deepmerge from '@mui/utils/deepmerge';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import SelectInput from './SelectInput';
import formControlState from '../FormControl/formControlState';
import useFormControl from '../FormControl/useFormControl';
Expand Down Expand Up @@ -86,7 +86,7 @@ const Select = React.forwardRef(function Select(inProps, ref) {
filled: <StyledFilledInput ownerState={ownerState} />,
}[variant];

const inputComponentRef = useForkRef(ref, getReactNodeRef(InputComponent));
const inputComponentRef = useForkRef(ref, getReactElementRef(InputComponent));

return (
<React.Fragment>
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Slide/Slide.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { Transition } from 'react-transition-group';
import chainPropTypes from '@mui/utils/chainPropTypes';
import HTMLElementType from '@mui/utils/HTMLElementType';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import debounce from '../utils/debounce';
import useForkRef from '../utils/useForkRef';
import { useTheme } from '../zero-styled';
Expand Down Expand Up @@ -120,7 +120,7 @@ const Slide = React.forwardRef(function Slide(props, ref) {
} = props;

const childrenRef = React.useRef(null);
const handleRef = useForkRef(getReactNodeRef(children), childrenRef, ref);
const handleRef = useForkRef(getReactElementRef(children), childrenRef, ref);

const normalizedTransitionCallback = (callback) => (isAppearing) => {
if (callback) {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Tooltip/Tooltip.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { alpha } from '@mui/system/colorManipulator';
import { useRtl } from '@mui/system/RtlProvider';
import isFocusVisible from '@mui/utils/isFocusVisible';
import appendOwnerState from '@mui/utils/appendOwnerState';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { styled, useTheme } from '../zero-styled';
import memoTheme from '../utils/memoTheme';
import { useDefaultProps } from '../DefaultPropsProvider';
Expand Down Expand Up @@ -555,7 +555,7 @@ const Tooltip = React.forwardRef(function Tooltip(inProps, ref) {
};
}, [handleClose, open]);

const handleRef = useForkRef(getReactNodeRef(children), setChildNode, ref);
const handleRef = useForkRef(getReactElementRef(children), setChildNode, ref);

// There is no point in displaying an empty tooltip.
// So we exclude all falsy values, except 0, which is valid.
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Unstable_TrapFocus/FocusTrap.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
elementAcceptingRef,
unstable_useForkRef as useForkRef,
unstable_ownerDocument as ownerDocument,
unstable_getReactNodeRef as getReactNodeRef,
unstable_getReactElementRef as getReactElementRef,
} from '@mui/utils';
import { FocusTrapProps } from './FocusTrap.types';

Expand Down Expand Up @@ -145,7 +145,7 @@ function FocusTrap(props: FocusTrapProps): React.JSX.Element {
const activated = React.useRef(false);

const rootRef = React.useRef<HTMLElement>(null);
const handleRef = useForkRef(getReactNodeRef(children), rootRef);
const handleRef = useForkRef(getReactElementRef(children), rootRef);
const lastKeydown = React.useRef<KeyboardEvent | null>(null);

React.useEffect(() => {
Expand Down
4 changes: 2 additions & 2 deletions packages/mui-material/src/Zoom/Zoom.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import * as React from 'react';
import PropTypes from 'prop-types';
import { Transition } from 'react-transition-group';
import elementAcceptingRef from '@mui/utils/elementAcceptingRef';
import getReactNodeRef from '@mui/utils/getReactNodeRef';
import getReactElementRef from '@mui/utils/getReactElementRef';
import { useTheme } from '../zero-styled';
import { reflow, getTransitionProps } from '../transitions/utils';
import useForkRef from '../utils/useForkRef';
Expand Down Expand Up @@ -49,7 +49,7 @@ const Zoom = React.forwardRef(function Zoom(props, ref) {
} = props;

const nodeRef = React.useRef(null);
const handleRef = useForkRef(nodeRef, getReactNodeRef(children), ref);
const handleRef = useForkRef(nodeRef, getReactElementRef(children), ref);

const normalizedTransitionCallback = (callback) => (maybeIsAppearing) => {
if (callback) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import getReactElementRef from '@mui/utils/getReactElementRef';
import * as React from 'react';

// @ts-expect-error
getReactElementRef(false);

// @ts-expect-error
getReactElementRef(null);

// @ts-expect-error
getReactElementRef(undefined);

// @ts-expect-error
getReactElementRef(1);

// @ts-expect-error
getReactElementRef([<div key="1" />, <div key="2" />]);

getReactElementRef(<div />);
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { expect } from 'chai';
import getReactElementRef from '@mui/utils/getReactElementRef';
import * as React from 'react';

describe('getReactElementRef', () => {
it('should return undefined when not used correctly', () => {
aarongarciah marked this conversation as resolved.
Show resolved Hide resolved
// @ts-expect-error
expect(getReactElementRef(false)).to.equal(undefined);
// @ts-expect-error
expect(getReactElementRef()).to.equal(undefined);
// @ts-expect-error
expect(getReactElementRef(1)).to.equal(undefined);

const children = [<div key="1" />, <div key="2" />];
// @ts-expect-error
expect(getReactElementRef(children)).to.equal(undefined);
});

it('should return the ref of a React element', () => {
const ref = React.createRef<HTMLDivElement>();
const element = <div ref={ref} />;
expect(getReactElementRef(element)).to.equal(ref);
});

it('should return null for a fragment', () => {
const element = (
<React.Fragment>
<p>Hello</p>
<p>Hello</p>
</React.Fragment>
);
expect(getReactElementRef(element)).to.equal(null);
});

it('should return null for element with no ref', () => {
const element = <div />;
expect(getReactElementRef(element)).to.equal(null);
});
});
20 changes: 20 additions & 0 deletions packages/mui-utils/src/getReactElementRef/getReactElementRef.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import * as React from 'react';

/**
* Returns the ref of a React element handling differences between React 19 and older versions.
* It will throw runtime error if the element is not a valid React element.
*
* @param element React.ReactElement
* @returns React.Ref<any> | null | undefined
*/
export default function getReactElementRef(
element: React.ReactElement,
): React.Ref<any> | null | undefined {
// 'ref' is passed as prop in React 19, whereas 'ref' is directly attached to children in older versions
if (parseInt(React.version, 10) >= 19) {
return element.props?.ref;
}
// @ts-expect-error element.ref is not included in the ReactElement type
// https://github.com/DefinitelyTyped/DefinitelyTyped/discussions/70189
return element?.ref;
}
1 change: 1 addition & 0 deletions packages/mui-utils/src/getReactElementRef/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export { default } from './getReactElementRef';
22 changes: 0 additions & 22 deletions packages/mui-utils/src/getReactNodeRef/getReactNodeRef.ts
aarongarciah marked this conversation as resolved.
Show resolved Hide resolved

This file was deleted.

1 change: 0 additions & 1 deletion packages/mui-utils/src/getReactNodeRef/index.ts

This file was deleted.

2 changes: 1 addition & 1 deletion packages/mui-utils/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,5 @@ export { default as unstable_useSlotProps } from './useSlotProps';
export type { UseSlotPropsParameters, UseSlotPropsResult } from './useSlotProps';
export { default as unstable_resolveComponentProps } from './resolveComponentProps';
export { default as unstable_extractEventHandlers } from './extractEventHandlers';
export { default as unstable_getReactNodeRef } from './getReactNodeRef';
export { default as unstable_getReactElementRef } from './getReactElementRef';
export * from './types';
4 changes: 2 additions & 2 deletions packages/mui-utils/src/useForkRef/useForkRef.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import * as React from 'react';
import { expect } from 'chai';
import { createRenderer, screen } from '@mui/internal-test-utils';
import useForkRef from './useForkRef';
import getReactNodeRef from '../getReactNodeRef';
import getReactElementRef from '../getReactElementRef';

describe('useForkRef', () => {
const { render } = createRenderer();
Expand Down Expand Up @@ -48,7 +48,7 @@ describe('useForkRef', () => {
it('does nothing if none of the forked branches requires a ref', () => {
const Outer = React.forwardRef(function Outer(props, ref) {
const { children } = props;
const handleRef = useForkRef(getReactNodeRef(children), ref);
const handleRef = useForkRef(getReactElementRef(children), ref);

return React.cloneElement(children, { ref: handleRef });
});
Expand Down
Loading